mirror of
https://github.com/clearlinux/WALinuxAgent.git
synced 2026-06-16 02:45:59 +00:00
Fix pylint warnings
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
###############################################################################
|
||||
# Set default behavior to automatically normalize line endings.
|
||||
###############################################################################
|
||||
* text=auto
|
||||
|
||||
###############################################################################
|
||||
# Set default behavior for command prompt diff.
|
||||
#
|
||||
# This is need for earlier builds of msysgit that does not have it on by
|
||||
# default for csharp files.
|
||||
# Note: This is only used by command line
|
||||
###############################################################################
|
||||
#*.cs diff=csharp
|
||||
|
||||
###############################################################################
|
||||
# Set the merge driver for project and solution files
|
||||
#
|
||||
# Merging from the command prompt will add diff markers to the files if there
|
||||
# are conflicts (Merging from VS is not affected by the settings below, in VS
|
||||
# the diff markers are never inserted). Diff markers may cause the following
|
||||
# file extensions to fail to load in VS. An alternative would be to treat
|
||||
# these files as binary and thus will always conflict and require user
|
||||
# intervention with every merge. To do so, just uncomment the entries below
|
||||
###############################################################################
|
||||
#*.sln merge=binary
|
||||
#*.csproj merge=binary
|
||||
#*.vbproj merge=binary
|
||||
#*.vcxproj merge=binary
|
||||
#*.vcproj merge=binary
|
||||
#*.dbproj merge=binary
|
||||
#*.fsproj merge=binary
|
||||
#*.lsproj merge=binary
|
||||
#*.wixproj merge=binary
|
||||
#*.modelproj merge=binary
|
||||
#*.sqlproj merge=binary
|
||||
#*.wwaproj merge=binary
|
||||
|
||||
###############################################################################
|
||||
# behavior for image files
|
||||
#
|
||||
# image files are treated as binary by default.
|
||||
###############################################################################
|
||||
#*.jpg binary
|
||||
#*.png binary
|
||||
#*.gif binary
|
||||
|
||||
###############################################################################
|
||||
# diff behavior for common document formats
|
||||
#
|
||||
# Convert binary document formats to text before diffing them. This feature
|
||||
# is only available from the command line. Turn it on by uncommenting the
|
||||
# entries below.
|
||||
###############################################################################
|
||||
#*.doc diff=astextplain
|
||||
#*.DOC diff=astextplain
|
||||
#*.docx diff=astextplain
|
||||
#*.DOCX diff=astextplain
|
||||
#*.dot diff=astextplain
|
||||
#*.DOT diff=astextplain
|
||||
#*.pdf diff=astextplain
|
||||
#*.PDF diff=astextplain
|
||||
#*.rtf diff=astextplain
|
||||
#*.RTF diff=astextplain
|
||||
@@ -54,3 +54,6 @@ docs/_build/
|
||||
target/
|
||||
|
||||
waagentc
|
||||
*.pyproj
|
||||
*.sln
|
||||
*.suo
|
||||
|
||||
Executable → Regular
+5
-10
@@ -1,4 +1,3 @@
|
||||
#!/bin/bash
|
||||
# Copyright 2014 Microsoft Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,14 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Implements parts of RFC 2131, 1541, 1497 and
|
||||
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
|
||||
# http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
read Line
|
||||
import azurelinuxagent.agent as agent
|
||||
|
||||
if [ $Line == "y" ] ; then
|
||||
exit 0
|
||||
else
|
||||
exit 1
|
||||
fi
|
||||
print "hehe"
|
||||
agent.main()
|
||||
+72
-45
@@ -17,38 +17,48 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
"""
|
||||
Module agent
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import traceback
|
||||
import threading
|
||||
import subprocess
|
||||
import azurelinuxagent.logger as logger
|
||||
from azurelinuxagent.metadata import GuestAgentName, GuestAgentLongVersion, \
|
||||
DistroName, DistroVersion, DistroFullName
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.handler import Handlers
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
from azurelinuxagent.metadata import AGENT_NAME, agent_long_version, \
|
||||
DISTRO_NAME, DISTRO_VERSION
|
||||
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
from azurelinuxagent.handler import HANDLERS
|
||||
|
||||
|
||||
def Init(verbose):
|
||||
Handlers.initHandler.init(verbose)
|
||||
|
||||
def Run():
|
||||
Handlers.runHandler.run()
|
||||
|
||||
def Deprovision(force=False, deluser=False):
|
||||
Handlers.deprovisionHandler.deprovision(force=force, deluser=deluser)
|
||||
|
||||
def ParseArgs(sysArgv):
|
||||
def init(verbose):
|
||||
"""
|
||||
Initialize agent running environment.
|
||||
"""
|
||||
HANDLERS.init_handler.init(verbose)
|
||||
|
||||
def run():
|
||||
"""
|
||||
Run agent daemon
|
||||
"""
|
||||
HANDLERS.main_handler.run()
|
||||
|
||||
def deprovision(force=False, deluser=False):
|
||||
"""
|
||||
Run deprovision command
|
||||
"""
|
||||
HANDLERS.deprovision_handler.deprovision(force=force, deluser=deluser)
|
||||
|
||||
def parse_args(sys_args):
|
||||
"""
|
||||
Parse command line arguments
|
||||
"""
|
||||
cmd = "help"
|
||||
force = False
|
||||
verbose = False
|
||||
for a in sysArgv:
|
||||
if re.match("^([-/]*)deprovision\+user", a):
|
||||
for a in sys_args:
|
||||
if re.match("^([-/]*)deprovision\\+user", a):
|
||||
cmd = "deprovision+user"
|
||||
elif re.match("^([-/]*)deprovision", a):
|
||||
cmd = "deprovision"
|
||||
@@ -64,48 +74,65 @@ def ParseArgs(sysArgv):
|
||||
verbose = True
|
||||
elif re.match("^([-/]*)force", a):
|
||||
force = True
|
||||
elif re.match("^([-/]*)(help|usage|\?)", a):
|
||||
elif re.match("^([-/]*)(help|usage|\\?)", a):
|
||||
cmd = "help"
|
||||
else:
|
||||
cmd = "help"
|
||||
break
|
||||
return cmd, force, verbose
|
||||
|
||||
def Version():
|
||||
print("{0} running on {1} {2}".format(GuestAgentLongVersion, DistroName,
|
||||
DistroVersion))
|
||||
def Usage():
|
||||
def version():
|
||||
"""
|
||||
Show agent version
|
||||
"""
|
||||
print("{0} running on {1} {2}".format(agent_long_version, DISTRO_NAME,
|
||||
DISTRO_VERSION))
|
||||
def usage():
|
||||
"""
|
||||
Show agent usage
|
||||
"""
|
||||
print("")
|
||||
print(("usage: {0} [-verbose] [-force] [-help]"
|
||||
"-deprovision[+user]|-register-service|-version|-daemon|-start]"
|
||||
"").format(sys.argv[0]))
|
||||
print("")
|
||||
|
||||
def Start():
|
||||
def start():
|
||||
"""
|
||||
Start agent daemon in a background process and set stdout/stderr to
|
||||
/dev/null
|
||||
"""
|
||||
devnull = open(os.devnull, 'w')
|
||||
subprocess.Popen([sys.argv[0], '-daemon'], stdout=devnull, stderr=devnull)
|
||||
|
||||
def RegisterService():
|
||||
print "Register {0} service".format(GuestAgentName)
|
||||
OSUtil.RegisterAgentService()
|
||||
print "Start {0} service".format(GuestAgentName)
|
||||
OSUtil.StartAgentService()
|
||||
def register_service():
|
||||
"""
|
||||
Register agent as a service
|
||||
"""
|
||||
print "Register {0} service".format(AGENT_NAME)
|
||||
OSUTIL.register_agent_service()
|
||||
print "Start {0} service".format(AGENT_NAME)
|
||||
OSUTIL.start_agent_service()
|
||||
|
||||
def Main():
|
||||
command, force, verbose = ParseArgs(sys.argv[1:])
|
||||
def main():
|
||||
"""
|
||||
Parse command line arguments, exit with usage() on error.
|
||||
Invoke different methods according to different command
|
||||
"""
|
||||
command, force, verbose = parse_args(sys.argv[1:])
|
||||
if command == "version":
|
||||
Version()
|
||||
version()
|
||||
elif command == "help":
|
||||
Usage()
|
||||
else:
|
||||
Init(verbose)
|
||||
usage()
|
||||
else:
|
||||
init(verbose)
|
||||
if command == "deprovision+user":
|
||||
Deprovision(force, deluser=True)
|
||||
deprovision(force, deluser=True)
|
||||
elif command == "deprovision":
|
||||
Deprovision(force, deluser=False)
|
||||
deprovision(force, deluser=False)
|
||||
elif command == "start":
|
||||
Start()
|
||||
start()
|
||||
elif command == "register-service":
|
||||
RegisterService()
|
||||
register_service()
|
||||
elif command == "daemon":
|
||||
Run()
|
||||
run()
|
||||
|
||||
+45
-27
@@ -17,9 +17,12 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
"""
|
||||
Module conf loads and parses configuration file
|
||||
"""
|
||||
import os
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
from azurelinuxagent.exception import *
|
||||
from azurelinuxagent.exception import AgentConfigError
|
||||
|
||||
class ConfigurationProvider(object):
|
||||
"""
|
||||
@@ -40,52 +43,67 @@ class ConfigurationProvider(object):
|
||||
else:
|
||||
self.values[parts[0]] = None
|
||||
|
||||
def get(self, key, defaultValue=None):
|
||||
def get(self, key, default_val=None):
|
||||
val = self.values.get(key)
|
||||
return val if val is not None else defaultValue
|
||||
return val if val is not None else default_val
|
||||
|
||||
def getSwitch(self, key, defaultValue=False):
|
||||
def get_switch(self, key, default_val=False):
|
||||
val = self.values.get(key)
|
||||
if val is not None and val.lower() == 'y':
|
||||
return True
|
||||
elif val is not None and val.lower() == 'n':
|
||||
return False
|
||||
return defaultValue
|
||||
return default_val
|
||||
|
||||
def getInt(self, key, defaultValue=-1):
|
||||
def get_int(self, key, default_val=-1):
|
||||
try:
|
||||
return int(self.values.get(key))
|
||||
except:
|
||||
return defaultValue
|
||||
except TypeError:
|
||||
return default_val
|
||||
except ValueError:
|
||||
return default_val
|
||||
|
||||
|
||||
__Config__ = ConfigurationProvider()
|
||||
__config__ = ConfigurationProvider()
|
||||
|
||||
def LoadConfiguration(confFilePath, conf=__Config__):
|
||||
if os.path.isfile(confFilePath) == False:
|
||||
raise AgentConfigError("Missing configuration in {0}".format(confFilePath))
|
||||
def load_conf(conf_file_path, conf=__config__):
|
||||
"""
|
||||
Load conf file from: conf_file_path
|
||||
"""
|
||||
if os.path.isfile(conf_file_path) == False:
|
||||
raise AgentConfigError(("Missing configuration in {0}"
|
||||
"").format(conf_file_path))
|
||||
try:
|
||||
content = fileutil.GetFileContents(confFilePath)
|
||||
content = fileutil.read_file(conf_file_path)
|
||||
conf.load(content)
|
||||
except IOError, e:
|
||||
raise AgentConfigError("Failed to load conf file:{0}".format(confFilePath))
|
||||
except IOError as err:
|
||||
raise AgentConfigError(("Failed to load conf file:{0}, {1}"
|
||||
"").format(conf_file_path, err))
|
||||
|
||||
def Get(key, defaultValue=None, conf=__Config__):
|
||||
def get(key, default_val=None, conf=__config__):
|
||||
"""
|
||||
Get option value by key, return default_val if not found
|
||||
"""
|
||||
if conf is not None:
|
||||
return conf.get(key, defaultValue)
|
||||
return conf.get(key, default_val)
|
||||
else:
|
||||
return defaultValue
|
||||
|
||||
def GetSwitch(key, defaultValue=None, conf=__Config__):
|
||||
if conf is not None:
|
||||
return conf.getSwitch(key, defaultValue)
|
||||
else:
|
||||
return defaultValue
|
||||
return default_val
|
||||
|
||||
def GetInt(key, defaultValue=None, conf=__Config__):
|
||||
def get_switch(key, default_val=None, conf=__config__):
|
||||
"""
|
||||
Get bool option value by key, return default_val if not found
|
||||
"""
|
||||
if conf is not None:
|
||||
return conf.getInt(key, defaultValue)
|
||||
return conf.get_switch(key, default_val)
|
||||
else:
|
||||
return defaultValue
|
||||
return default_val
|
||||
|
||||
def get_int(key, default_val=None, conf=__config__):
|
||||
"""
|
||||
Get int option value by key, return default_val if not found
|
||||
"""
|
||||
if conf is not None:
|
||||
return conf.get_int(key, default_val)
|
||||
else:
|
||||
return default_val
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
from azurelinuxagent.metadata import DistroName, DistroVersion
|
||||
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION
|
||||
import azurelinuxagent.distro.redhat.loader as redhat
|
||||
|
||||
def GetOSUtil():
|
||||
return redhat.GetOSUtil()
|
||||
|
||||
def get_osutil():
|
||||
return redhat.get_osutil()
|
||||
|
||||
|
||||
@@ -21,10 +21,10 @@ import azurelinuxagent.utils.fileutil as fileutil
|
||||
from azurelinuxagent.distro.default.deprovision import DeprovisionHandler, DeprovisionAction
|
||||
|
||||
class CoreOSDeprovisionHandler(DeprovisionHandler):
|
||||
def setUp(self, deluser):
|
||||
warnings, actions = super(CoreOSDeprovisionHandler, self).setUp(deluser)
|
||||
def setup(self, deluser):
|
||||
warnings, actions = super(CoreOSDeprovisionHandler, self).setup(deluser)
|
||||
warnings.append("WARNING! /etc/machine-id will be removed.")
|
||||
filesToDel = ['/etc/machine-id']
|
||||
actions.append(DeprovisionAction(fileutil.RemoveFiles, filesToDel))
|
||||
files_to_del = ['/etc/machine-id']
|
||||
actions.append(DeprovisionAction(fileutil.rm_files, files_to_del))
|
||||
return warnings, actions
|
||||
|
||||
|
||||
@@ -23,5 +23,5 @@ from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory
|
||||
class CoreOSHandlerFactory(DefaultHandlerFactory):
|
||||
def __init__(self):
|
||||
super(CoreOSHandlerFactory, self).__init__()
|
||||
self.deprovisionHandler = CoreOSDeprovisionHandler()
|
||||
self.deprovision_handler = CoreOSDeprovisionHandler()
|
||||
|
||||
|
||||
@@ -18,11 +18,11 @@
|
||||
#
|
||||
|
||||
|
||||
def GetOSUtil():
|
||||
def get_osutil():
|
||||
from azurelinuxagent.distro.coreos.osutil import CoreOSUtil
|
||||
return CoreOSUtil()
|
||||
|
||||
def GetHandlers():
|
||||
def get_handlers():
|
||||
from azurelinuxagent.distro.coreos.handlerFactory import CoreOSHandlerFactory
|
||||
return CoreOSHandlerFactory()
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class CoreOSUtil(DefaultOSUtil):
|
||||
super(CoreOSUtil, self).__init__()
|
||||
self.waagent_path='/usr/share/oem/bin/waagent'
|
||||
self.python_path='/usr/share/oem/python/bin'
|
||||
self.configPath = '/usr/share/oem/waagent.conf'
|
||||
self.conf_path = '/usr/share/oem/waagent.conf'
|
||||
if 'PATH' in os.environ:
|
||||
path = "{0}:{1}".format(os.environ['PATH'], self.python_path)
|
||||
else:
|
||||
@@ -51,40 +51,40 @@ class CoreOSUtil(DefaultOSUtil):
|
||||
py_path = self.waagent_path
|
||||
os.environ['PYTHONPATH'] = py_path
|
||||
|
||||
def IsSysUser(self, userName):
|
||||
def is_sys_user(self, username):
|
||||
#User 'core' is not a sysuser
|
||||
if userName == 'core':
|
||||
if username == 'core':
|
||||
return False
|
||||
return super(CoreOSUtil, self).IsSysUser(userName)
|
||||
return super(CoreOSUtil, self).IsSysUser(username)
|
||||
|
||||
def IsDhcpEnabled(self):
|
||||
def is_dhcp_enabled(self):
|
||||
return True
|
||||
|
||||
def StartNetwork(self) :
|
||||
return shellutil.Run("systemctl start systemd-networkd", chk_err=False)
|
||||
|
||||
def RestartInterface(self, iface):
|
||||
shellutil.Run("systemctl restart systemd-networkd")
|
||||
|
||||
def RestartSshService(self):
|
||||
return shellutil.Run("systemctl restart sshd", chk_err=False)
|
||||
def start_network(self) :
|
||||
return shellutil.run("systemctl start systemd-networkd", chk_err=False)
|
||||
|
||||
def StopDhcpService(self):
|
||||
return shellutil.Run("systemctl stop systemd-networkd", chk_err=False)
|
||||
def restart_if(self, iface):
|
||||
shellutil.run("systemctl restart systemd-networkd")
|
||||
|
||||
def StartDhcpService(self):
|
||||
return shellutil.Run("systemctl start systemd-networkd", chk_err=False)
|
||||
def restart_ssh_service(self):
|
||||
return shellutil.run("systemctl restart sshd", chk_err=False)
|
||||
|
||||
def StartAgentService(self):
|
||||
return shellutil.Run("systemctl start wagent", chk_err=False)
|
||||
def stop_dhcp_service(self):
|
||||
return shellutil.run("systemctl stop systemd-networkd", chk_err=False)
|
||||
|
||||
def StopAgentService(self):
|
||||
return shellutil.Run("systemctl stop wagent", chk_err=False)
|
||||
|
||||
def GetDhcpProcessId(self):
|
||||
ret= shellutil.RunGetOutput("pidof systemd-networkd")
|
||||
def start_dhcp_service(self):
|
||||
return shellutil.run("systemctl start systemd-networkd", chk_err=False)
|
||||
|
||||
def start_agent_service(self):
|
||||
return shellutil.run("systemctl start wagent", chk_err=False)
|
||||
|
||||
def stop_agent_service(self):
|
||||
return shellutil.run("systemctl stop wagent", chk_err=False)
|
||||
|
||||
def get_dhcp_pid(self):
|
||||
ret= shellutil.run_get_output("pidof systemd-networkd")
|
||||
return ret[1] if ret[0] == 0 else None
|
||||
|
||||
def TranslateCustomData(self, data):
|
||||
|
||||
def decode_customdata(self, data):
|
||||
return base64.b64decode(data)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#
|
||||
|
||||
|
||||
def GetOSUtil():
|
||||
def get_osutil():
|
||||
from azurelinuxagent.distro.debian.osutil import DebianOSUtil
|
||||
return DebianOSUtil()
|
||||
|
||||
|
||||
@@ -30,18 +30,18 @@ import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import azurelinuxagent.utils.textutil as textutil
|
||||
from azurelinuxagent.distro.default.osutil import OSUtil
|
||||
from azurelinuxagent.distro.default.osutil import DefaultOSUtil
|
||||
|
||||
class DebianOSUtil(OSUtil):
|
||||
class DebianOSUtil(DefaultOSUtil):
|
||||
def __init__(self):
|
||||
super(DebianOSUtil, self).__init__()
|
||||
|
||||
def RestartSshService(self):
|
||||
return shellutil.Run("service sshd restart", chk_err=False)
|
||||
def restart_ssh_service(self):
|
||||
return shellutil.run("service sshd restart", chk_err=False)
|
||||
|
||||
def StopAgentService(self):
|
||||
return shellutil.Run("service azurelinuxagent stop", chk_err=False)
|
||||
def stop_agent_service(self):
|
||||
return shellutil.run("service azurelinuxagent stop", chk_err=False)
|
||||
|
||||
def StartAgentService(self):
|
||||
return shellutil.Run("service azurelinuxagent start", chk_err=False)
|
||||
def start_agent_service(self):
|
||||
return shellutil.run("service azurelinuxagent start", chk_err=False)
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@
|
||||
#
|
||||
|
||||
import azurelinuxagent.conf as conf
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
import azurelinuxagent.protocol as prot
|
||||
import azurelinuxagent.protocol.ovfenv as ovf
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
@@ -35,74 +35,74 @@ class DeprovisionAction(object):
|
||||
|
||||
class DeprovisionHandler(object):
|
||||
|
||||
def deleteRootPassword(self, warnings, actions):
|
||||
def del_root_password(self, warnings, actions):
|
||||
warnings.append("WARNING! root password will be disabled. "
|
||||
"You will not be able to login as root.")
|
||||
|
||||
actions.append(DeprovisionAction(OSUtil.DeleteRootPassword))
|
||||
|
||||
def deleteUser(self, warnings, actions):
|
||||
|
||||
actions.append(DeprovisionAction(OSUTIL.del_root_password))
|
||||
|
||||
def del_user(self, warnings, actions):
|
||||
|
||||
try:
|
||||
ovfenv = ovf.GetOvfEnv()
|
||||
ovfenv = ovf.get_ovf_env()
|
||||
except prot.ProtocolError:
|
||||
warnings.append("WARNING! ovf-env.xml is not found.")
|
||||
warnings.append("WARNING! Skip delete user.")
|
||||
return
|
||||
|
||||
userName = ovfenv.getUserName()
|
||||
username = ovfenv.get_username()
|
||||
warnings.append(("WARNING! {0} account and entire home directory "
|
||||
"will be deleted.").format(userName))
|
||||
actions.append(DeprovisionAction(OSUtil.DeleteAccount, [userName]))
|
||||
"will be deleted.").format(username))
|
||||
actions.append(DeprovisionAction(OSUTIL.del_account, [username]))
|
||||
|
||||
|
||||
def regenerateHostKeyPair(self, warnings, actions):
|
||||
def regen_ssh_host_key(self, warnings, actions):
|
||||
warnings.append("WARNING! All SSH host key pairs will be deleted.")
|
||||
actions.append(DeprovisionAction(OSUtil.SetHostname,
|
||||
actions.append(DeprovisionAction(OSUTIL.set_hostname,
|
||||
['localhost.localdomain']))
|
||||
actions.append(DeprovisionAction(shellutil.Run,
|
||||
actions.append(DeprovisionAction(shellutil.run,
|
||||
['rm -f /etc/ssh/ssh_host_*key*']))
|
||||
|
||||
def stopAgentService(self, warnings, actions):
|
||||
|
||||
def stop_agent_service(self, warnings, actions):
|
||||
warnings.append("WARNING! The waagent service will be stopped.")
|
||||
actions.append(DeprovisionAction(OSUtil.StopAgentService))
|
||||
actions.append(DeprovisionAction(OSUTIL.stop_agent_service))
|
||||
|
||||
def deleteFiles(self, warnings, actions):
|
||||
filesToDel = ['/root/.bash_history', '/var/log/waagent.log']
|
||||
actions.append(DeprovisionAction(fileutil.RemoveFiles, filesToDel))
|
||||
def del_files(self, warnings, actions):
|
||||
files_to_del = ['/root/.bash_history', '/var/log/waagent.log']
|
||||
actions.append(DeprovisionAction(fileutil.rm_files, files_to_del))
|
||||
|
||||
def deleteDhcpLease(self, warnings, actions):
|
||||
def del_dhcp_lease(self, warnings, actions):
|
||||
warnings.append("WARNING! Cached DHCP leases will be deleted.")
|
||||
dirsToDel = ["/var/lib/dhclient", "/var/lib/dhcpcd", "/var/lib/dhcp"]
|
||||
actions.append(DeprovisionAction(fileutil.CleanupDirs, dirsToDel))
|
||||
dirs_to_del = ["/var/lib/dhclient", "/var/lib/dhcpcd", "/var/lib/dhcp"]
|
||||
actions.append(DeprovisionAction(fileutil.rm_dirs, dirs_to_del))
|
||||
|
||||
def deleteLibDir(self, warnings, actions):
|
||||
dirsToDel = [OSUtil.GetLibDir()]
|
||||
actions.append(DeprovisionAction(fileutil.CleanupDirs, dirsToDel))
|
||||
def del_lib_dir(self, warnings, actions):
|
||||
dirs_to_del = [OSUTIL.get_lib_dir()]
|
||||
actions.append(DeprovisionAction(fileutil.rm_dirs, dirs_to_del))
|
||||
|
||||
def setUp(self, deluser):
|
||||
def setup(self, deluser):
|
||||
warnings = []
|
||||
actions = []
|
||||
|
||||
self.stopAgentService(warnings, actions)
|
||||
if conf.GetSwitch("Provisioning.RegenerateSshHostkey", False):
|
||||
self.regenerateHostKeyPair(warnings, actions)
|
||||
|
||||
self.deleteDhcpLease(warnings, actions)
|
||||
self.stop_agent_service(warnings, actions)
|
||||
if conf.get_switch("Provisioning.RegenerateSshHostkey", False):
|
||||
self.regen_ssh_host_key(warnings, actions)
|
||||
|
||||
if conf.GetSwitch("Provisioning.DeleteRootPassword", False):
|
||||
self.deleteRootPassword(warnings, actions)
|
||||
self.del_dhcp_lease(warnings, actions)
|
||||
|
||||
self.deleteLibDir(warnings, actions)
|
||||
self.deleteFiles(warnings, actions)
|
||||
if conf.get_switch("Provisioning.DeleteRootPassword", False):
|
||||
self.del_root_password(warnings, actions)
|
||||
|
||||
self.del_lib_dir(warnings, actions)
|
||||
self.del_files(warnings, actions)
|
||||
|
||||
if deluser:
|
||||
self.deleteUser(warnings, actions)
|
||||
|
||||
self.del_user(warnings, actions)
|
||||
|
||||
return warnings, actions
|
||||
|
||||
|
||||
def deprovision(self, force=False, deluser=False):
|
||||
warnings, actions = self.setUp(deluser)
|
||||
warnings, actions = self.setup(deluser)
|
||||
for warning in warnings:
|
||||
print warning
|
||||
|
||||
@@ -110,8 +110,8 @@ class DeprovisionHandler(object):
|
||||
confirm = raw_input("Do you want to proceed (y/n)")
|
||||
if not confirm.lower().startswith('y'):
|
||||
return
|
||||
|
||||
|
||||
for action in actions:
|
||||
action.invoke()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -20,171 +20,172 @@ import socket
|
||||
import array
|
||||
import time
|
||||
import azurelinuxagent.logger as logger
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
from azurelinuxagent.exception import AgentNetworkError
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
from azurelinuxagent.utils.textutil import *
|
||||
|
||||
WireServerAddrFile="WireServer"
|
||||
WIRE_SERVER_ADDR_FILE_NAME="WireServer"
|
||||
|
||||
class DhcpHandler(object):
|
||||
def __init__(self):
|
||||
self.endpoint = None
|
||||
self.gateway = None
|
||||
self.routes = None
|
||||
|
||||
def waitForNetwork(self):
|
||||
ipv4 = OSUtil.GetIpv4Address()
|
||||
def wait_for_network(self):
|
||||
ipv4 = OSUTIL.get_ip4_addr()
|
||||
while ipv4 == '' or ipv4 == '0.0.0.0':
|
||||
logger.Info("Waiting for network.")
|
||||
logger.info("Waiting for network.")
|
||||
time.sleep(10)
|
||||
OSUtil.StartNetwork()
|
||||
ipv4 = OSUtil.GetIpv4Address()
|
||||
OSUTIL.start_network()
|
||||
ipv4 = OSUTIL.get_ip4_addr()
|
||||
|
||||
def probe(self):
|
||||
logger.Info("Send dhcp request")
|
||||
self.waitForNetwork()
|
||||
macAddress = OSUtil.GetMacAddress()
|
||||
req = BuildDhcpRequest(macAddress)
|
||||
resp = SendDhcpRequest(req)
|
||||
endpoint, gateway, routes = ParseDhcpResponse(resp)
|
||||
logger.info("Send dhcp request")
|
||||
self.wait_for_network()
|
||||
mac_addr = OSUTIL.get_mac_addr()
|
||||
req = build_dhcp_request(mac_addr)
|
||||
resp = send_dhcp_request(req)
|
||||
endpoint, gateway, routes = parse_dhcp_resp(resp)
|
||||
self.endpoint = endpoint
|
||||
logger.Info("Wire server endpoint:{0}", endpoint)
|
||||
logger.Info("Gateway:{0}", gateway)
|
||||
logger.Info("Routes:{0}", routes)
|
||||
logger.info("Wire server endpoint:{0}", endpoint)
|
||||
logger.info("Gateway:{0}", gateway)
|
||||
logger.info("Routes:{0}", routes)
|
||||
if endpoint is not None:
|
||||
path = os.path.join(OSUtil.GetLibDir(), WireServerAddrFile)
|
||||
fileutil.SetFileContents(path, endpoint)
|
||||
path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME)
|
||||
fileutil.write_file(path, endpoint)
|
||||
self.gateway = gateway
|
||||
self.routes = routes
|
||||
self.configRoutes()
|
||||
self.conf_routes()
|
||||
|
||||
def getEndpoint(self):
|
||||
def get_endpoint(self):
|
||||
return self.endpoint
|
||||
|
||||
def configRoutes(self):
|
||||
logger.Info("Configure routes")
|
||||
def conf_routes(self):
|
||||
logger.info("Configure routes")
|
||||
#Add default gateway
|
||||
if self.gateway is not None:
|
||||
OSUtil.RouteAdd(0 , 0, self.gateway)
|
||||
OSUTIL.route_add(0 , 0, self.gateway)
|
||||
if self.routes is not None:
|
||||
for route in self.routes:
|
||||
OSUtil.RouteAdd(route[0], route[1], route[2])
|
||||
OSUTIL.route_add(route[0], route[1], route[2])
|
||||
|
||||
def ValidateDhcpResponse(request, response):
|
||||
bytesReceived = len(response)
|
||||
if bytesReceived < 0xF6:
|
||||
logger.Error("HandleDhcpResponse: Too few bytes received:{0}",
|
||||
str(bytesReceived))
|
||||
def validate_dhcp_resp(request, response):
|
||||
bytes_recv = len(response)
|
||||
if bytes_recv < 0xF6:
|
||||
logger.error("HandleDhcpResponse: Too few bytes received:{0}",
|
||||
str(bytes_recv))
|
||||
return False
|
||||
|
||||
logger.Verbose("BytesReceived:{0}", hex(bytesReceived))
|
||||
logger.Verbose("DHCP response:{0}", HexDump(response, bytesReceived))
|
||||
logger.verb("BytesReceived:{0}", hex(bytes_recv))
|
||||
logger.verb("DHCP response:{0}", hex_dump(response, bytes_recv))
|
||||
|
||||
# check transactionId, cookie, MAC address cookie should never mismatch
|
||||
# transactionId and MAC address may mismatch if we see a response
|
||||
# transactionId and MAC address may mismatch if we see a response
|
||||
# meant from another machine
|
||||
if not CompareBytes(request, response, 0xEC, 4):
|
||||
logger.Verbose("Cookie not match:\nsend={0},\nreceive={1}",
|
||||
HexDump3(request, 0xEC, 4),
|
||||
HexDump3(response, 0xEC, 4))
|
||||
if not compare_bytes(request, response, 0xEC, 4):
|
||||
logger.verb("Cookie not match:\nsend={0},\nreceive={1}",
|
||||
hex_dump3(request, 0xEC, 4),
|
||||
hex_dump3(response, 0xEC, 4))
|
||||
raise AgentNetworkError("Cookie in dhcp respones "
|
||||
"doesn't match the request")
|
||||
|
||||
if not CompareBytes(request, response, 4, 4):
|
||||
logger.Verbose("TransactionID not match:\nsend={0},\nreceive={1}",
|
||||
HexDump3(request, 4, 4),
|
||||
HexDump3(response, 4, 4))
|
||||
if not compare_bytes(request, response, 4, 4):
|
||||
logger.verb("TransactionID not match:\nsend={0},\nreceive={1}",
|
||||
hex_dump3(request, 4, 4),
|
||||
hex_dump3(response, 4, 4))
|
||||
raise AgentNetworkError("TransactionID in dhcp respones "
|
||||
"doesn't match the request")
|
||||
|
||||
if not CompareBytes(request, response, 0x1C, 6):
|
||||
logger.Verbose("Mac Address not match:\nsend={0},\nreceive={1}",
|
||||
HexDump3(request, 0x1C, 6),
|
||||
HexDump3(response, 0x1C, 6))
|
||||
if not compare_bytes(request, response, 0x1C, 6):
|
||||
logger.verb("Mac Address not match:\nsend={0},\nreceive={1}",
|
||||
hex_dump3(request, 0x1C, 6),
|
||||
hex_dump3(response, 0x1C, 6))
|
||||
raise AgentNetworkError("Mac Addr in dhcp respones "
|
||||
"doesn't match the request")
|
||||
|
||||
def ParseRoute(response, option, i, length, bytesReceived):
|
||||
def parse_route(response, option, i, length, bytes_recv):
|
||||
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
|
||||
logger.Verbose("Routes at offset: {0} with length:{1}",
|
||||
hex(i),
|
||||
logger.verb("Routes at offset: {0} with length:{1}",
|
||||
hex(i),
|
||||
hex(length))
|
||||
routes = []
|
||||
if length < 5:
|
||||
logger.Error("Data too small for option:{0}", str(option))
|
||||
logger.error("Data too small for option:{0}", str(option))
|
||||
j = i + 2
|
||||
while j < (i + length + 2):
|
||||
maskLengthBits = Ord(response[j])
|
||||
maskLengthBytes = (((maskLengthBits + 7) & ~7) >> 3)
|
||||
mask = 0xFFFFFFFF & (0xFFFFFFFF << (32 - maskLengthBits))
|
||||
mask_len_bits = str_to_ord(response[j])
|
||||
mask_len_bytes = (((mask_len_bits + 7) & ~7) >> 3)
|
||||
mask = 0xFFFFFFFF & (0xFFFFFFFF << (32 - mask_len_bits))
|
||||
j += 1
|
||||
net = UnpackBigEndian(response, j, maskLengthBytes)
|
||||
net <<= (32 - maskLengthBytes * 8)
|
||||
net = unpack_big_endian(response, j, mask_len_bytes)
|
||||
net <<= (32 - mask_len_bytes * 8)
|
||||
net &= mask
|
||||
j += maskLengthBytes
|
||||
gateway = UnpackBigEndian(response, j, 4)
|
||||
j += mask_len_bytes
|
||||
gateway = unpack_big_endian(response, j, 4)
|
||||
j += 4
|
||||
routes.append((net, mask, gateway))
|
||||
if j != (i + length + 2):
|
||||
logger.Error("Unable to parse routes")
|
||||
logger.error("Unable to parse routes")
|
||||
return routes
|
||||
|
||||
def ParseIpAddress(response, option, i, length, bytesReceived):
|
||||
if i + 5 < bytesReceived:
|
||||
def parse_ip_addr(response, option, i, length, bytes_recv):
|
||||
if i + 5 < bytes_recv:
|
||||
if length != 4:
|
||||
logger.Error("Endpoint or Default Gateway not 4 bytes")
|
||||
logger.error("Endpoint or Default Gateway not 4 bytes")
|
||||
return None
|
||||
addr = UnpackBigEndian(response, i + 2, 4)
|
||||
IpAddress = IntegerToIpAddressV4String(addr)
|
||||
return IpAddress
|
||||
addr = unpack_big_endian(response, i + 2, 4)
|
||||
ip_addr = int_to_ip4_addr(addr)
|
||||
return ip_addr
|
||||
else:
|
||||
logger.Error("Data too small for option:{0}", str(option))
|
||||
logger.error("Data too small for option:{0}", str(option))
|
||||
return None
|
||||
|
||||
def ParseDhcpResponse(response):
|
||||
def parse_dhcp_resp(response):
|
||||
"""
|
||||
Parse DHCP response:
|
||||
Returns endpoint server or None on error.
|
||||
"""
|
||||
logger.Verbose("parse Dhcp Response")
|
||||
bytesReceived = len(response)
|
||||
logger.verb("parse Dhcp Response")
|
||||
bytes_recv = len(response)
|
||||
endpoint = None
|
||||
gateway = None
|
||||
routes = None
|
||||
|
||||
# Walk all the returned options, parsing out what we need, ignoring the
|
||||
# Walk all the returned options, parsing out what we need, ignoring the
|
||||
# others. We need the custom option 245 to find the the endpoint we talk to,
|
||||
# as well as, to handle some Linux DHCP client incompatibilities,
|
||||
# options 3 for default gateway and 249 for routes. And 255 is end.
|
||||
|
||||
i = 0xF0 # offset to first option
|
||||
while i < bytesReceived:
|
||||
option = Ord(response[i])
|
||||
while i < bytes_recv:
|
||||
option = str_to_ord(response[i])
|
||||
length = 0
|
||||
if (i + 1) < bytesReceived:
|
||||
length = Ord(response[i + 1])
|
||||
logger.Verbose("DHCP option {0} at offset:{1} with length:{2}",
|
||||
hex(option),
|
||||
hex(i),
|
||||
if (i + 1) < bytes_recv:
|
||||
length = str_to_ord(response[i + 1])
|
||||
logger.verb("DHCP option {0} at offset:{1} with length:{2}",
|
||||
hex(option),
|
||||
hex(i),
|
||||
hex(length))
|
||||
if option == 255:
|
||||
logger.Verbose("DHCP packet ended at offset:{0}", hex(i))
|
||||
logger.verb("DHCP packet ended at offset:{0}", hex(i))
|
||||
break
|
||||
elif option == 249:
|
||||
routes = ParseRoute(response, option, i, length, bytesReceived)
|
||||
routes = parse_route(response, option, i, length, bytes_recv)
|
||||
elif option == 3:
|
||||
gateway = ParseIpAddress(response, option, i, length, bytesReceived)
|
||||
logger.Verbose("Default gateway:{0}, at {1}",
|
||||
gateway,
|
||||
gateway = parse_ip_addr(response, option, i, length, bytes_recv)
|
||||
logger.verb("Default gateway:{0}, at {1}",
|
||||
gateway,
|
||||
hex(i))
|
||||
elif option == 245:
|
||||
endpoint = ParseIpAddress(response, option, i, length, bytesReceived)
|
||||
logger.Verbose("Azure wire protocol endpoint:{0}, at {1}",
|
||||
gateway,
|
||||
endpoint = parse_ip_addr(response, option, i, length, bytes_recv)
|
||||
logger.verb("Azure wire protocol endpoint:{0}, at {1}",
|
||||
gateway,
|
||||
hex(i))
|
||||
else:
|
||||
logger.Verbose("Skipping DHCP option:{0} at {1} with length {2}",
|
||||
logger.verb("Skipping DHCP option:{0} at {1} with length {2}",
|
||||
hex(option),
|
||||
hex(i),
|
||||
hex(length))
|
||||
@@ -192,71 +193,71 @@ def ParseDhcpResponse(response):
|
||||
return endpoint, gateway, routes
|
||||
|
||||
|
||||
def AllowBroadcastForDhcp(func):
|
||||
def allow_dhcp_broadcast(func):
|
||||
"""
|
||||
Temporary allow broadcase for dhcp. Remove the route when done.
|
||||
"""
|
||||
def Wrapper(*args, **kwargs):
|
||||
missingDefaultRoute = OSUtil.IsMissingDefaultRoute()
|
||||
ifname = OSUtil.GetInterfaceName()
|
||||
if missingDefaultRoute:
|
||||
OSUtil.SetBroadcastRouteForDhcp(ifname)
|
||||
def wrapper(*args, **kwargs):
|
||||
missing_default_route = OSUTIL.is_missing_default_route()
|
||||
ifname = OSUTIL.get_if_name()
|
||||
if missing_default_route:
|
||||
OSUTIL.set_route_for_dhcp_broadcast(ifname)
|
||||
result = func(*args, **kwargs)
|
||||
if missingDefaultRoute:
|
||||
OSUtil.RemoveBroadcastRouteForDhcp(ifname)
|
||||
if missing_default_route:
|
||||
OSUTIL.remove_route_for_dhcp_broadcast(ifname)
|
||||
return result
|
||||
return Wrapper
|
||||
return wrapper
|
||||
|
||||
def DisableDhcpServiceIfNeeded(func):
|
||||
def disable_dhcp_service(func):
|
||||
"""
|
||||
In some distros, dhcp service needs to be shutdown before agent probe
|
||||
endpoint through dhcp.
|
||||
"""
|
||||
def Wrapper(*args, **kwargs):
|
||||
if OSUtil.IsDhcpEnabled():
|
||||
OSUtil.StopDhcpService()
|
||||
def wrapper(*args, **kwargs):
|
||||
if OSUTIL.is_dhcp_enabled():
|
||||
OSUTIL.stop_dhcp_service()
|
||||
result = func(*args, **kwargs)
|
||||
OSUtil.StartDhcpService()
|
||||
OSUTIL.start_dhcp_service()
|
||||
return result
|
||||
else:
|
||||
return func(*args, **kwargs)
|
||||
return Wrapper
|
||||
return wrapper
|
||||
|
||||
|
||||
@AllowBroadcastForDhcp
|
||||
@DisableDhcpServiceIfNeeded
|
||||
def SendDhcpRequest(request):
|
||||
__SleepDuration = [0, 10, 30, 60, 60]
|
||||
@allow_dhcp_broadcast
|
||||
@disable_dhcp_service
|
||||
def send_dhcp_request(request):
|
||||
__waiting_duration__ = [0, 10, 30, 60, 60]
|
||||
sock = None
|
||||
for duration in __SleepDuration:
|
||||
for duration in __waiting_duration__:
|
||||
try:
|
||||
OSUtil.OpenPortForDhcp()
|
||||
response = SocketSend(request)
|
||||
ValidateDhcpResponse(request, response)
|
||||
OSUTIL.allow_dhcp_broadcast()
|
||||
response = socket_send(request)
|
||||
validate_dhcp_resp(request, response)
|
||||
return response
|
||||
except AgentNetworkError as e:
|
||||
logger.Error("Failed to send DHCP request: {0}", e)
|
||||
logger.error("Failed to send DHCP request: {0}", e)
|
||||
return None
|
||||
finally:
|
||||
if sock:
|
||||
sock.close()
|
||||
time.sleep(duration)
|
||||
|
||||
def SocketSend(request):
|
||||
sock = socket.socket(socket.AF_INET,
|
||||
socket.SOCK_DGRAM,
|
||||
def socket_send(request):
|
||||
sock = socket.socket(socket.AF_INET,
|
||||
socket.SOCK_DGRAM,
|
||||
socket.IPPROTO_UDP)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
|
||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
|
||||
sock.bind(("0.0.0.0", 68))
|
||||
sock.bind(("0.0.0.0", 68))
|
||||
sock.sendto(request, ("<broadcast>", 67))
|
||||
sock.settimeout(10)
|
||||
logger.Verbose("Send DHCP request: Setting socket.timeout=10, "
|
||||
logger.verb("Send DHCP request: Setting socket.timeout=10, "
|
||||
"entering recv")
|
||||
response = sock.recv(1024)
|
||||
return response
|
||||
|
||||
def BuildDhcpRequest(macAddress):
|
||||
def build_dhcp_request(mac_addr):
|
||||
"""
|
||||
Build DHCP request string.
|
||||
"""
|
||||
@@ -291,7 +292,7 @@ def BuildDhcpRequest(macAddress):
|
||||
# (struct.pack_into would be good here, but requires Python 2.5)
|
||||
request = [0] * 244
|
||||
|
||||
transactionID = GenTransactionId()
|
||||
trans_id = gen_trans_id()
|
||||
|
||||
# Opcode = 1
|
||||
# HardwareAddressType = 1 (ethernet/MAC)
|
||||
@@ -301,15 +302,15 @@ def BuildDhcpRequest(macAddress):
|
||||
|
||||
# fill in transaction id (random number to ensure response matches request)
|
||||
for a in range(0, 4):
|
||||
request[4 + a] = Ord(transactionID[a])
|
||||
request[4 + a] = str_to_ord(trans_id[a])
|
||||
|
||||
logger.Verbose("BuildDhcpRequest: transactionId:%s,%04X" % (
|
||||
HexDump2(transactionID),
|
||||
UnpackBigEndian(request, 4, 4)))
|
||||
logger.verb("BuildDhcpRequest: transactionId:%s,%04X" % (
|
||||
hex_dump2(trans_id),
|
||||
unpack_big_endian(request, 4, 4)))
|
||||
|
||||
# fill in ClientHardwareAddress
|
||||
for a in range(0, 6):
|
||||
request[0x1C + a] = Ord(macAddress[a])
|
||||
request[0x1C + a] = str_to_ord(mac_addr[a])
|
||||
|
||||
# DHCP Magic Cookie: 99, 130, 83, 99
|
||||
# MessageTypeCode = 53 DHCP Message Type
|
||||
@@ -320,5 +321,5 @@ def BuildDhcpRequest(macAddress):
|
||||
request[0xEC + a] = [99, 130, 83, 99, 53, 1, 1, 255][a]
|
||||
return array.array("B", request)
|
||||
|
||||
def GenTransactionId():
|
||||
def gen_trans_id():
|
||||
return os.urandom(4)
|
||||
|
||||
@@ -23,7 +23,7 @@ import threading
|
||||
import time
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.conf as conf
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
|
||||
class EnvHandler(object):
|
||||
"""
|
||||
@@ -31,35 +31,35 @@ class EnvHandler(object):
|
||||
If dhcp clinet process re-start has occurred, reset routes, dhcp with fabric.
|
||||
|
||||
Monitor scsi disk.
|
||||
If new scsi disk found, set
|
||||
If new scsi disk found, set
|
||||
"""
|
||||
def __init__(self, handlers):
|
||||
self.monitor = EnvMonitor(handlers.dhcpHandler)
|
||||
self.monitor = EnvMonitor(handlers.dhcp_handler)
|
||||
|
||||
def startMonitor(self):
|
||||
def start(self):
|
||||
self.monitor.start()
|
||||
|
||||
def stopMonitor(self):
|
||||
def stop(self):
|
||||
self.monitor.stop()
|
||||
|
||||
class EnvMonitor(object):
|
||||
|
||||
def __init__(self, dhcpHandler):
|
||||
self.dhcpHandler = dhcpHandler
|
||||
def __init__(self, dhcp_handler):
|
||||
self.dhcp_handler = dhcp_handler
|
||||
self.stopped = True
|
||||
self.hostname = None
|
||||
self.dhcpid = None
|
||||
self.server_thread=None
|
||||
|
||||
|
||||
def start(self):
|
||||
if not self.stopped:
|
||||
logger.Info("Stop existing env monitor service.")
|
||||
logger.info("Stop existing env monitor service.")
|
||||
self.stop()
|
||||
|
||||
self.stopped = False
|
||||
logger.Info("Start env monitor service.")
|
||||
logger.info("Start env monitor service.")
|
||||
self.hostname = socket.gethostname()
|
||||
self.dhcpid = OSUtil.GetDhcpProcessId()
|
||||
self.dhcpid = OSUTIL.get_dhcp_pid()
|
||||
self.server_thread = threading.Thread(target = self.monitor)
|
||||
self.server_thread.setDaemon(True)
|
||||
self.server_thread.start()
|
||||
@@ -70,39 +70,39 @@ class EnvMonitor(object):
|
||||
If dhcp clinet process re-start has occurred, reset routes.
|
||||
"""
|
||||
while not self.stopped:
|
||||
OSUtil.RemoveRulesFiles()
|
||||
timeout = conf.Get("OS.RootDeviceScsiTimeout", None)
|
||||
OSUTIL.remove_rules_files()
|
||||
timeout = conf.get("OS.RootDeviceScsiTimeout", None)
|
||||
if timeout is not None:
|
||||
OSUtil.SetScsiDiskTimeout(timeout)
|
||||
if conf.GetSwitch("Provisioning.MonitorHostName", False):
|
||||
self.handleHostnameUpdate()
|
||||
self.handleDhcpClientRestart()
|
||||
OSUTIL.set_scsi_disks_timeout(timeout)
|
||||
if conf.get_switch("Provisioning.MonitorHostName", False):
|
||||
self.handle_hostname_update()
|
||||
self.handle_dhclient_restart()
|
||||
time.sleep(5)
|
||||
|
||||
def handleHostnameUpdate(self):
|
||||
currHostname = socket.gethostname()
|
||||
if currHostname != self.hostname:
|
||||
logger.Info("EnvMonitor: Detected host name change: {0} -> {1}",
|
||||
self.hostname, currHostname)
|
||||
OSUtil.SetHostname(currHostname)
|
||||
OSUtil.PublishHostname(currHostname)
|
||||
self.hostname = currHostname
|
||||
def handle_hostname_update(self):
|
||||
curr_hostname = socket.gethostname()
|
||||
if curr_hostname != self.hostname:
|
||||
logger.info("EnvMonitor: Detected host name change: {0} -> {1}",
|
||||
self.hostname, curr_hostname)
|
||||
OSUTIL.set_hostname(curr_hostname)
|
||||
OSUTIL.publish_hostname(curr_hostname)
|
||||
self.hostname = curr_hostname
|
||||
|
||||
def handleDhcpClientRestart(self):
|
||||
def handle_dhclient_restart(self):
|
||||
if self.dhcpid is None:
|
||||
logger.Warn("Dhcp client is not running. ")
|
||||
self.dhcpid = OSUtil.GetDhcpProcessId()
|
||||
logger.warn("Dhcp client is not running. ")
|
||||
self.dhcpid = OSUTIL.get_dhcp_pid()
|
||||
return
|
||||
|
||||
|
||||
#The dhcp process hasn't changed since last check
|
||||
if os.path.isdir(os.path.join('/proc', self.dhcpid.strip())):
|
||||
return
|
||||
|
||||
newpid = OSUtil.GetDhcpProcessId()
|
||||
|
||||
newpid = OSUTIL.get_dhcp_pid()
|
||||
if newpid is not None and newpid != self.dhcpid:
|
||||
logger.Info("EnvMonitor: Detected dhcp client restart. "
|
||||
logger.info("EnvMonitor: Detected dhcp client restart. "
|
||||
"Restoring routing table.")
|
||||
self.dhcpHandler.configRoutes()
|
||||
self.dhcp_handler.conf_routes()
|
||||
self.dhcpid = newpid
|
||||
|
||||
def stop(self):
|
||||
|
||||
@@ -22,22 +22,22 @@ import time
|
||||
import json
|
||||
import subprocess
|
||||
import azurelinuxagent.logger as logger
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
import azurelinuxagent.protocol as prot
|
||||
from azurelinuxagent.event import AddExtensionEvent, WALAEventOperation
|
||||
from azurelinuxagent.event import add_event, WALAEventOperation
|
||||
from azurelinuxagent.exception import ExtensionError
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.utils.restutil as restutil
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
|
||||
ValidExtensionStatus = ['transitioning', 'error', 'success', 'warning']
|
||||
VALID_EXTENSION_STATUS = ['transitioning', 'error', 'success', 'warning']
|
||||
|
||||
def validate_has_key(obj, key, fullName):
|
||||
def validate_has_key(obj, key, fullname):
|
||||
if key not in obj:
|
||||
raise ExtensionError("Missing: {0}".format(fullName))
|
||||
raise ExtensionError("Missing: {0}".format(fullname))
|
||||
|
||||
def validate_in_range(val, validRange, name):
|
||||
if val not in validRange:
|
||||
def validate_in_range(val, valid_range, name):
|
||||
if val not in valid_range:
|
||||
raise ExtensionError("Invalid {0}: {1}".format(name, val))
|
||||
|
||||
def try_get(dictionary, key, default=None):
|
||||
@@ -52,12 +52,12 @@ def extension_sub_status_to_v2(substatus):
|
||||
validate_has_key(substatus, 'status', 'substatus/status')
|
||||
validate_has_key(substatus, 'code', 'substatus/code')
|
||||
validate_has_key(substatus, 'formattedMessage', 'substatus/formattedMessage')
|
||||
validate_has_key(substatus['formattedMessage'], 'lang',
|
||||
validate_has_key(substatus['formattedMessage'], 'lang',
|
||||
'substatus/formattedMessage/lang')
|
||||
validate_has_key(substatus['formattedMessage'], 'message',
|
||||
validate_has_key(substatus['formattedMessage'], 'message',
|
||||
'substatus/formattedMessage/message')
|
||||
|
||||
validate_in_range(substatus['status'], ValidExtensionStatus,
|
||||
validate_in_range(substatus['status'], VALID_EXTENSION_STATUS,
|
||||
'substatus/status')
|
||||
status = prot.ExtensionSubStatus()
|
||||
status.name = try_get(substatus, 'name')
|
||||
@@ -66,212 +66,212 @@ def extension_sub_status_to_v2(substatus):
|
||||
status.message = try_get(substatus['formattedMessage'], 'message')
|
||||
return status
|
||||
|
||||
def extension_status_to_v2(extStatus, seqNo):
|
||||
def ext_status_to_v2(ext_status, seq_no):
|
||||
#Check extension status format
|
||||
validate_has_key(extStatus, 'status', 'status')
|
||||
validate_has_key(extStatus['status'], 'status', 'status/status')
|
||||
validate_has_key(extStatus['status'], 'operation', 'status/operation')
|
||||
validate_has_key(extStatus['status'], 'code', 'status/code')
|
||||
validate_has_key(extStatus['status'], 'name', 'status/name')
|
||||
validate_has_key(extStatus['status'], 'formattedMessage',
|
||||
validate_has_key(ext_status, 'status', 'status')
|
||||
validate_has_key(ext_status['status'], 'status', 'status/status')
|
||||
validate_has_key(ext_status['status'], 'operation', 'status/operation')
|
||||
validate_has_key(ext_status['status'], 'code', 'status/code')
|
||||
validate_has_key(ext_status['status'], 'name', 'status/name')
|
||||
validate_has_key(ext_status['status'], 'formattedMessage',
|
||||
'status/formattedMessage')
|
||||
validate_has_key(extStatus['status']['formattedMessage'], 'lang',
|
||||
validate_has_key(ext_status['status']['formattedMessage'], 'lang',
|
||||
'status/formattedMessage/lang')
|
||||
validate_has_key(extStatus['status']['formattedMessage'], 'message',
|
||||
validate_has_key(ext_status['status']['formattedMessage'], 'message',
|
||||
'status/formattedMessage/message')
|
||||
|
||||
validate_in_range(extStatus['status']['status'], ValidExtensionStatus,
|
||||
validate_in_range(ext_status['status']['status'], VALID_EXTENSION_STATUS,
|
||||
'status/status')
|
||||
|
||||
status = prot.ExtensionStatus()
|
||||
status.name = try_get(extStatus['status'], 'name')
|
||||
status.configurationAppliedTime = try_get(extStatus['status'],
|
||||
'configurationAppliedTime')
|
||||
status.operation = try_get(extStatus['status'], 'operation')
|
||||
status.status = try_get(extStatus['status'], 'status')
|
||||
status.code = try_get(extStatus['status'], 'code')
|
||||
status.message = try_get(extStatus['status']['formattedMessage'], 'message')
|
||||
status.sequenceNumber = seqNo
|
||||
|
||||
substatusList = try_get(extStatus['status'], 'substatus', [])
|
||||
for substatus in substatusList:
|
||||
status = prot.ExtensionStatus()
|
||||
status.name = try_get(ext_status['status'], 'name')
|
||||
status.configurationAppliedTime = try_get(ext_status['status'],
|
||||
'configurationAppliedTime')
|
||||
status.operation = try_get(ext_status['status'], 'operation')
|
||||
status.status = try_get(ext_status['status'], 'status')
|
||||
status.code = try_get(ext_status['status'], 'code')
|
||||
status.message = try_get(ext_status['status']['formattedMessage'], 'message')
|
||||
status.sequenceNumber = seq_no
|
||||
|
||||
substatus_list = try_get(ext_status['status'], 'substatus', [])
|
||||
for substatus in substatus_list:
|
||||
status.substatusList.extend(extension_sub_status_to_v2(substatus))
|
||||
return status
|
||||
|
||||
class ExtensionsHandler(object):
|
||||
|
||||
def process(self):
|
||||
protocol = prot.Factory.getDefaultProtocol()
|
||||
extList = protocol.getExtensions()
|
||||
|
||||
handlerStatusList = []
|
||||
for extension in extList.extensions:
|
||||
#TODO handle extension in parallel
|
||||
packageList = protocol.getExtensionPackages(extension)
|
||||
handlerStatus = self.processExtension(extension, packageList)
|
||||
handlerStatusList.append(handlerStatus)
|
||||
protocol = prot.Factory.get_default_protocol()
|
||||
ext_list = protocol.get_extensions()
|
||||
|
||||
return handlerStatusList
|
||||
|
||||
def processExtension(self, extension, packageList):
|
||||
installedVersion = GetInstalledExtensionVersion(extension.name)
|
||||
if installedVersion is not None:
|
||||
ext = ExtensionInstance(extension, packageList,
|
||||
installedVersion, installed=True)
|
||||
h_status_list = []
|
||||
for extension in ext_list.extensions:
|
||||
#TODO handle extension in parallel
|
||||
pkg_list = protocol.get_extension_pkgs(extension)
|
||||
h_status = self.process_extension(extension, pkg_list)
|
||||
h_status_list.append(h_status)
|
||||
|
||||
return h_status_list
|
||||
|
||||
def process_extension(self, extension, pkg_list):
|
||||
installed_version = get_installed_version(extension.name)
|
||||
if installed_version is not None:
|
||||
ext = ExtensionInstance(extension, pkg_list,
|
||||
installed_version, installed=True)
|
||||
else:
|
||||
ext = ExtensionInstance(extension, packageList,
|
||||
ext = ExtensionInstance(extension, pkg_list,
|
||||
extension.properties.version)
|
||||
try:
|
||||
ext.initLog()
|
||||
ext.init_logger()
|
||||
ext.handle()
|
||||
status = ext.collectHandlerStatus()
|
||||
status = ext.collect_handler_status()
|
||||
except ExtensionError as e:
|
||||
logger.Error("Failed to handle extension: {0}-{1}\n {2}",
|
||||
ext.getName(), ext.getVersion(), e)
|
||||
AddExtensionEvent(name=ext.getName(), isSuccess=False,
|
||||
op=ext.getCurrOperation(), message = str(e))
|
||||
extStatus = prot.ExtensionStatus(status='error', code='-1',
|
||||
operation = ext.getCurrOperation(),
|
||||
logger.error("Failed to handle extension: {0}-{1}\n {2}",
|
||||
ext.get_name(), ext.get_version(), e)
|
||||
add_event(name=ext.get_name(), is_success=False,
|
||||
op=ext.get_curr_op(), message = str(e))
|
||||
ext_status = prot.ExtensionStatus(status='error', code='-1',
|
||||
operation = ext.get_curr_op(),
|
||||
message = str(e),
|
||||
sequenceNumber = ext.getSeqNo())
|
||||
status = ext.createHandlerStatus(extStatus)
|
||||
seq_no = ext.get_seq_no())
|
||||
status = ext.create_handler_status(ext_status)
|
||||
status.status = "Ready"
|
||||
return status
|
||||
|
||||
def ParseExtensionDirName(dirName):
|
||||
def parse_extension_dirname(dirname):
|
||||
"""
|
||||
Parse installed extension dir name. Sample: ExtensionName-Version/
|
||||
"""
|
||||
seprator = dirName.rfind('-')
|
||||
seprator = dirname.rfind('-')
|
||||
if seprator < 0:
|
||||
raise ExtensionError("Invalid extenation dir name")
|
||||
return dirName[0:seprator], dirName[seprator + 1:]
|
||||
return dirname[0:seprator], dirname[seprator + 1:]
|
||||
|
||||
def GetInstalledExtensionVersion(targetName):
|
||||
def get_installed_version(target_name):
|
||||
"""
|
||||
Return the highest version instance with the same name
|
||||
"""
|
||||
installedVersion = None
|
||||
libDir = OSUtil.GetLibDir()
|
||||
for dirName in os.listdir(libDir):
|
||||
path = os.path.join(libDir, dirName)
|
||||
if os.path.isdir(path) and dirName.startswith(targetName):
|
||||
name, version = ParseExtensionDirName(dirName)
|
||||
installed_version = None
|
||||
lib_dir = OSUTIL.get_lib_dir()
|
||||
for dir_name in os.listdir(lib_dir):
|
||||
path = os.path.join(lib_dir, dir_name)
|
||||
if os.path.isdir(path) and dir_name.startswith(target_name):
|
||||
name, version = parse_extension_dirname(dir_name)
|
||||
#Here we need to ensure names are exactly the same.
|
||||
if name == targetName:
|
||||
if installedVersion is None or installedVersion < version:
|
||||
installedVersion = version
|
||||
return installedVersion
|
||||
if name == target_name:
|
||||
if installed_version is None or installed_version < version:
|
||||
installed_version = version
|
||||
return installed_version
|
||||
|
||||
class ExtensionInstance(object):
|
||||
def __init__(self, extension, packageList, currVersion, installed=False):
|
||||
def __init__(self, extension, pkg_list, curr_version, installed=False):
|
||||
self.extension = extension
|
||||
self.packageList = packageList
|
||||
self.currVersion = currVersion
|
||||
self.libDir = OSUtil.GetLibDir()
|
||||
self.pkg_list = pkg_list
|
||||
self.curr_version = curr_version
|
||||
self.lib_dir = OSUTIL.get_lib_dir()
|
||||
self.installed = installed
|
||||
self.settings = None
|
||||
|
||||
|
||||
#Extension will have no more than 1 settings instance
|
||||
if len(extension.properties.extensions) > 0:
|
||||
self.settings = extension.properties.extensions[0]
|
||||
self.enabled = False
|
||||
self.currOperation = None
|
||||
self.curr_op = None
|
||||
|
||||
prefix = "[{0}]".format(self.getFullName())
|
||||
self.logger = logger.Logger(logger.DefaultLogger, prefix)
|
||||
|
||||
def initLog(self):
|
||||
prefix = "[{0}]".format(self.get_full_name())
|
||||
self.logger = logger.Logger(logger.default_logger, prefix)
|
||||
|
||||
def init_logger(self):
|
||||
#Init logger appender for extension
|
||||
fileutil.CreateDir(self.getLogDir(), mode=0700)
|
||||
logFile = os.path.join(self.getLogDir(), "CommandExecution.log")
|
||||
self.logger.addLoggerAppender(logger.AppenderType.FILE,
|
||||
logger.LogLevel.INFO, logFile)
|
||||
|
||||
fileutil.mkdir(self.get_log_dir(), mode=0700)
|
||||
log_file = os.path.join(self.get_log_dir(), "CommandExecution.log")
|
||||
self.logger.add_appender(logger.AppenderType.FILE,
|
||||
logger.LogLevel.INFO, log_file)
|
||||
|
||||
def handle(self):
|
||||
self.logger.info("Process extension settings:")
|
||||
self.logger.info(" Name: {0}", self.getName())
|
||||
self.logger.info(" Version: {0}", self.getVersion())
|
||||
|
||||
self.logger.info(" Name: {0}", self.get_name())
|
||||
self.logger.info(" Version: {0}", self.get_version())
|
||||
|
||||
if self.installed:
|
||||
self.logger.info("Installed version:{0}", self.currVersion)
|
||||
handlerStatus = self.getHandlerStatus()
|
||||
self.enabled = (handlerStatus == "Ready")
|
||||
|
||||
state = self.getState()
|
||||
self.logger.info("Installed version:{0}", self.curr_version)
|
||||
h_status = self.get_handler_status()
|
||||
self.enabled = (h_status == "Ready")
|
||||
|
||||
state = self.get_state()
|
||||
if state == 'enabled':
|
||||
self.handleEnable()
|
||||
self.handle_enable()
|
||||
elif state == 'disabled':
|
||||
self.handleDisable()
|
||||
self.handle_disable()
|
||||
elif state == 'uninstall':
|
||||
self.handleDisable()
|
||||
self.handleUninstall()
|
||||
self.handle_disable()
|
||||
self.handle_uninstall()
|
||||
else:
|
||||
raise ExtensionError("Unknown extension state:{0}".format(state))
|
||||
|
||||
def handleEnable(self):
|
||||
targetVersion = self.getTargetVersion()
|
||||
def handle_enable(self):
|
||||
target_version = self.get_target_version()
|
||||
if self.installed:
|
||||
if targetVersion > self.currVersion:
|
||||
self.upgrade(targetVersion)
|
||||
elif targetVersion == self.currVersion:
|
||||
if target_version > self.curr_version:
|
||||
self.upgrade(target_version)
|
||||
elif target_version == self.curr_version:
|
||||
self.enable()
|
||||
else:
|
||||
raise ExtensionError("A newer version has already been installed")
|
||||
else:
|
||||
if targetVersion > self.getVersion():
|
||||
if target_version > self.get_version():
|
||||
#This will happen when auto upgrade policy is enabled
|
||||
self.logger.info("Auto upgrade to new version:{0}",
|
||||
targetVersion)
|
||||
self.currVersion = targetVersion
|
||||
self.logger.info("Auto upgrade to new version:{0}",
|
||||
target_version)
|
||||
self.curr_version = target_version
|
||||
self.download()
|
||||
self.initExtensionDir()
|
||||
self.init_dir()
|
||||
self.install()
|
||||
self.enable()
|
||||
|
||||
def handleDisable(self):
|
||||
def handle_disable(self):
|
||||
if not self.installed or not self.enabled:
|
||||
return
|
||||
self.disable()
|
||||
|
||||
def handleUninstall(self):
|
||||
|
||||
def handle_uninstall(self):
|
||||
if not self.installed:
|
||||
return
|
||||
self.uninstall()
|
||||
|
||||
def upgrade(self, targetVersion):
|
||||
self.logger.info("Upgrade from: {0} to {1}", self.currVersion,
|
||||
targetVersion)
|
||||
self.currOperation=WALAEventOperation.Upgrade
|
||||
def upgrade(self, target_version):
|
||||
self.logger.info("Upgrade from: {0} to {1}", self.curr_version,
|
||||
target_version)
|
||||
self.curr_op=WALAEventOperation.Upgrade
|
||||
old = self
|
||||
new = ExtensionInstance(self.extension, self.packageList, targetVersion)
|
||||
new = ExtensionInstance(self.extension, self.pkg_list, target_version)
|
||||
self.logger.info("Download new extension package")
|
||||
new.initLog()
|
||||
new.init_logger()
|
||||
new.download()
|
||||
self.logger.info("Initialize new extension directory")
|
||||
new.initExtensionDir()
|
||||
new.init_dir()
|
||||
|
||||
old.disable()
|
||||
self.logger.info("Update new extension")
|
||||
new.update()
|
||||
old.uninstall()
|
||||
man = new.loadManifest()
|
||||
if man.isUpdateWithInstall():
|
||||
man = new.load_manifest()
|
||||
if man.is_update_with_install():
|
||||
self.logger.info("Install new extension")
|
||||
new.install()
|
||||
self.logger.info("Enable new extension")
|
||||
new.enable()
|
||||
AddExtensionEvent(name=self.getName(), isSuccess=True,
|
||||
op=self.currOperation, message="")
|
||||
add_event(name=self.get_name(), is_success=True,
|
||||
op=self.curr_op, message="")
|
||||
|
||||
def download(self):
|
||||
self.logger.info("Download extension package")
|
||||
self.currOperation=WALAEventOperation.Download
|
||||
uris = self.getPackageUris()
|
||||
self.curr_op=WALAEventOperation.Download
|
||||
uris = self.get_package_uris()
|
||||
package = None
|
||||
for uri in uris:
|
||||
try:
|
||||
resp = restutil.HttpGet(uri.uri, chkProxy=True)
|
||||
resp = restutil.http_get(uri.uri, chk_proxy=True)
|
||||
package = resp.read()
|
||||
break
|
||||
except restutil.HttpError as e:
|
||||
@@ -279,167 +279,167 @@ class ExtensionInstance(object):
|
||||
|
||||
if package is None:
|
||||
raise ExtensionError("Download extension failed")
|
||||
|
||||
|
||||
self.logger.info("Unpack extension package")
|
||||
pkgFile = os.path.join(self.libDir, os.path.basename(uri.uri) + ".zip")
|
||||
fileutil.SetFileContents(pkgFile, bytearray(package))
|
||||
zipfile.ZipFile(pkgFile).extractall(self.getBaseDir())
|
||||
chmod = "find {0} -type f | xargs chmod u+x".format(self.getBaseDir())
|
||||
shellutil.Run(chmod)
|
||||
AddExtensionEvent(name=self.getName(), isSuccess=True,
|
||||
op=self.currOperation, message="")
|
||||
|
||||
def initExtensionDir(self):
|
||||
pkg_file = os.path.join(self.lib_dir, os.path.basename(uri.uri) + ".zip")
|
||||
fileutil.write_file(pkg_file, bytearray(package))
|
||||
zipfile.ZipFile(pkg_file).extractall(self.get_base_dir())
|
||||
chmod = "find {0} -type f | xargs chmod u+x".format(self.get_base_dir())
|
||||
shellutil.run(chmod)
|
||||
add_event(name=self.get_name(), is_success=True,
|
||||
op=self.curr_op, message="")
|
||||
|
||||
def init_dir(self):
|
||||
self.logger.info("Initialize extension directory")
|
||||
#Save HandlerManifest.json
|
||||
manFile = fileutil.SearchForFile(self.getBaseDir(),
|
||||
man_file = fileutil.search_file(self.get_base_dir(),
|
||||
'HandlerManifest.json')
|
||||
man = fileutil.GetFileContents(manFile, removeBom=True)
|
||||
fileutil.SetFileContents(self.getManifestFile(), man)
|
||||
man = fileutil.read_file(man_file, remove_bom=True)
|
||||
fileutil.write_file(self.get_manifest_file(), man)
|
||||
|
||||
#Create status and config dir
|
||||
statusDir = self.getStatusDir()
|
||||
fileutil.CreateDir(statusDir, mode=0700)
|
||||
configDir = self.getConfigDir()
|
||||
fileutil.CreateDir(configDir, mode=0700)
|
||||
|
||||
status_dir = self.get_status_dir()
|
||||
fileutil.mkdir(status_dir, mode=0700)
|
||||
conf_dir = self.get_conf_dir()
|
||||
fileutil.mkdir(conf_dir, mode=0700)
|
||||
|
||||
#Init handler state to uninstall
|
||||
self.setHandlerStatus("NotReady")
|
||||
self.set_handler_status("NotReady")
|
||||
|
||||
#Save HandlerEnvironment.json
|
||||
self.createHandlerEnvironment()
|
||||
self.create_handler_env()
|
||||
|
||||
def enable(self):
|
||||
self.logger.info("Enable extension.")
|
||||
self.currOperation=WALAEventOperation.Enable
|
||||
man = self.loadManifest()
|
||||
self.launchCommand(man.getEnableCommand())
|
||||
self.setHandlerStatus("Ready")
|
||||
AddExtensionEvent(name=self.getName(), isSuccess=True,
|
||||
op=self.currOperation, message="")
|
||||
self.curr_op=WALAEventOperation.Enable
|
||||
man = self.load_manifest()
|
||||
self.launch_command(man.get_enable_command())
|
||||
self.set_handler_status("Ready")
|
||||
add_event(name=self.get_name(), is_success=True,
|
||||
op=self.curr_op, message="")
|
||||
|
||||
def disable(self):
|
||||
self.logger.info("Disable extension.")
|
||||
self.currOperation=WALAEventOperation.Disable
|
||||
man = self.loadManifest()
|
||||
self.launchCommand(man.getDisableCommand(), timeout=900)
|
||||
self.setHandlerStatus("Ready")
|
||||
AddExtensionEvent(name=self.getName(), isSuccess=True,
|
||||
op=self.currOperation, message="")
|
||||
self.curr_op=WALAEventOperation.Disable
|
||||
man = self.load_manifest()
|
||||
self.launch_command(man.get_disable_command(), timeout=900)
|
||||
self.set_handler_status("Ready")
|
||||
add_event(name=self.get_name(), is_success=True,
|
||||
op=self.curr_op, message="")
|
||||
|
||||
def install(self):
|
||||
self.logger.info("Install extension.")
|
||||
self.currOperation=WALAEventOperation.Install
|
||||
man = self.loadManifest()
|
||||
self.setHandlerStatus("Installing")
|
||||
self.launchCommand(man.getInstallCommand(), timeout=900)
|
||||
self.setHandlerStatus("Ready")
|
||||
AddExtensionEvent(name=self.getName(), isSuccess=True,
|
||||
op=self.currOperation, message="")
|
||||
self.curr_op=WALAEventOperation.Install
|
||||
man = self.load_manifest()
|
||||
self.set_handler_status("Installing")
|
||||
self.launch_command(man.get_install_command(), timeout=900)
|
||||
self.set_handler_status("Ready")
|
||||
add_event(name=self.get_name(), is_success=True,
|
||||
op=self.curr_op, message="")
|
||||
|
||||
def uninstall(self):
|
||||
self.logger.info("Uninstall extension.")
|
||||
self.currOperation=WALAEventOperation.UnInstall
|
||||
man = self.loadManifest()
|
||||
self.launchCommand(man.getUninstallCommand())
|
||||
self.setHandlerStatus("NotReady")
|
||||
AddExtensionEvent(name=self.getName(), isSuccess=True,
|
||||
op=self.currOperation, message="")
|
||||
self.curr_op=WALAEventOperation.UnInstall
|
||||
man = self.load_manifest()
|
||||
self.launch_command(man.get_uninstall_command())
|
||||
self.set_handler_status("NotReady")
|
||||
add_event(name=self.get_name(), is_success=True,
|
||||
op=self.curr_op, message="")
|
||||
|
||||
def update(self):
|
||||
self.logger.info("Update extension.")
|
||||
self.currOperation=WALAEventOperation.Update
|
||||
man = self.loadManifest()
|
||||
self.launchCommand(man.getUpdateCommand(), timeout=900)
|
||||
AddExtensionEvent(name=self.getName(), isSuccess=True,
|
||||
op=self.currOperation, message="")
|
||||
|
||||
def createHandlerStatus(self, extStatus, heartbeat=None):
|
||||
status = prot.ExtensionHandlerStatus()
|
||||
status.handlerName = self.getName()
|
||||
status.handlerVersion = self.getVersion()
|
||||
status.status = self.getHandlerStatus()
|
||||
status.extensionStatusList.append(extStatus)
|
||||
return status
|
||||
self.curr_op=WALAEventOperation.Update
|
||||
man = self.load_manifest()
|
||||
self.launch_command(man.get_update_command(), timeout=900)
|
||||
add_event(name=self.get_name(), is_success=True,
|
||||
op=self.curr_op, message="")
|
||||
|
||||
def collectHandlerStatus(self):
|
||||
man = self.loadManifest()
|
||||
heartbeat=None
|
||||
if man.isReportHeartbeat():
|
||||
heartbeat = self.getHeartbeat()
|
||||
extStatus = self.getExtensionStatus()
|
||||
status= self.createHandlerStatus(extStatus, heartbeat)
|
||||
status.status = self.getHandlerStatus()
|
||||
if heartbeat is not None:
|
||||
status.status = heartbeat['status']
|
||||
status.extensionStatusList.append(extStatus)
|
||||
def create_handler_status(self, ext_status, heartbeat=None):
|
||||
status = prot.ExtensionHandlerStatus()
|
||||
status.handlerName = self.get_name()
|
||||
status.handlerVersion = self.get_version()
|
||||
status.status = self.get_handler_status()
|
||||
status.extensionStatusList.append(ext_status)
|
||||
return status
|
||||
|
||||
def getExtensionStatus(self):
|
||||
extStatusFile = self.getStatusFile()
|
||||
def collect_handler_status(self):
|
||||
man = self.load_manifest()
|
||||
heartbeat=None
|
||||
if man.is_report_heartbeat():
|
||||
heartbeat = self.collect_heartbeat()
|
||||
ext_status = self.collect_extension_status()
|
||||
status= self.create_handler_status(ext_status, heartbeat)
|
||||
status.status = self.get_handler_status()
|
||||
if heartbeat is not None:
|
||||
status.status = heartbeat['status']
|
||||
status.extensionStatusList.append(ext_status)
|
||||
return status
|
||||
|
||||
def collect_extension_status(self):
|
||||
ext_status_file = self.get_status_file()
|
||||
try:
|
||||
extStatusJson = fileutil.GetFileContents(extStatusFile)
|
||||
extStatus = json.loads(extStatusJson)
|
||||
ext_status_str = fileutil.read_file(ext_status_file)
|
||||
ext_status = json.loads(ext_status_str)
|
||||
except IOError as e:
|
||||
raise ExtensionError("Failed to get status file: {0}".format(e))
|
||||
except ValueError as e:
|
||||
raise ExtensionError("Malformed status file: {0}".format(e))
|
||||
return extension_status_to_v2(extStatus[0],
|
||||
return ext_status_to_v2(ext_status[0],
|
||||
self.settings.sequenceNumber)
|
||||
|
||||
def getHandlerStatus(self):
|
||||
handlerStatus = "uninstalled"
|
||||
handlerStatusFile = self.getHandlerStateFile()
|
||||
|
||||
def get_handler_status(self):
|
||||
h_status = "uninstalled"
|
||||
h_status_file = self.get_handler_state_file()
|
||||
try:
|
||||
handlerStatus = fileutil.GetFileContents(handlerStatusFile)
|
||||
return handlerStatus
|
||||
h_status = fileutil.read_file(h_status_file)
|
||||
return h_status
|
||||
except IOError as e:
|
||||
raise ExtensionError("Failed to get handler status: {0}".format(e))
|
||||
|
||||
def setHandlerStatus(self, status):
|
||||
handlerStatusFile = self.getHandlerStateFile()
|
||||
def set_handler_status(self, status):
|
||||
h_status_file = self.get_handler_state_file()
|
||||
try:
|
||||
fileutil.SetFileContents(handlerStatusFile, status)
|
||||
fileutil.write_file(h_status_file, status)
|
||||
except IOError as e:
|
||||
raise ExtensionError("Failed to set handler status: {0}".format(e))
|
||||
|
||||
def getHeartbeat(self):
|
||||
def collect_heartbeat(self):
|
||||
self.logger.info("Collect heart beat")
|
||||
heartbeatFile = os.path.join(OSUtil.GetLibDir(),
|
||||
self.getHeartbeatFile())
|
||||
if not os.path.isfile(heartbeatFile):
|
||||
heartbeat_file = os.path.join(OSUTIL.get_lib_dir(),
|
||||
self.get_heartbeat_file())
|
||||
if not os.path.isfile(heartbeat_file):
|
||||
raise ExtensionError("Failed to get heart beat file")
|
||||
if not self.isResponsive(heartbeatFile):
|
||||
if not self.is_responsive(heartbeat_file):
|
||||
return {
|
||||
"status": "Unresponsive",
|
||||
"code": -1,
|
||||
"message": "Extension heartbeat is not responsive"
|
||||
}
|
||||
}
|
||||
try:
|
||||
heartbeatJson = fileutil.GetFileContents(heartbeatFile)
|
||||
heartbeat = json.loads(heartbeatJson)[0]['heartbeat']
|
||||
heartbeat_json = fileutil.read_file(heartbeat_file)
|
||||
heartbeat = json.loads(heartbeat_json)[0]['heartbeat']
|
||||
except IOError as e:
|
||||
raise ExtensionError("Failed to get heartbeat file:{0}".format(e))
|
||||
except ValueError as e:
|
||||
raise ExtensionError("Malformed heartbeat file: {0}".format(e))
|
||||
return heartbeat
|
||||
|
||||
def isResponsive(self, heartbeatFile):
|
||||
lastUpdate=int(time.time()-os.stat(heartbeatFile).st_mtime)
|
||||
return lastUpdate > 600 # not updated for more than 10 min
|
||||
def is_responsive(self, heartbeat_file):
|
||||
last_update=int(time.time()-os.stat(heartbeat_file).st_mtime)
|
||||
return last_update > 600 # not updated for more than 10 min
|
||||
|
||||
def launchCommand(self, cmd, timeout=300):
|
||||
def launch_command(self, cmd, timeout=300):
|
||||
self.logger.info("Launch command:{0}", cmd)
|
||||
baseDir = self.getBaseDir()
|
||||
self.updateSettings()
|
||||
base_dir = self.get_base_dir()
|
||||
self.update_settings()
|
||||
try:
|
||||
devnull = open(os.devnull, 'w')
|
||||
child = subprocess.Popen(baseDir + "/" + cmd, shell=True,
|
||||
cwd=baseDir, stdout=devnull)
|
||||
child = subprocess.Popen(base_dir + "/" + cmd, shell=True,
|
||||
cwd=base_dir, stdout=devnull)
|
||||
except Exception as e:
|
||||
#TODO do not catch all exception
|
||||
raise ExtensionError("Failed to launch: {0}, {1}".format(cmd, e))
|
||||
|
||||
|
||||
retry = timeout / 5
|
||||
while retry > 0 and child.poll == None:
|
||||
time.sleep(5)
|
||||
@@ -451,11 +451,11 @@ class ExtensionInstance(object):
|
||||
ret = child.wait()
|
||||
if ret == None or ret != 0:
|
||||
raise ExtensionError("Non-zero exit code: {0}, {1}".format(ret, cmd))
|
||||
|
||||
def loadManifest(self):
|
||||
manFile = self.getManifestFile()
|
||||
|
||||
def load_manifest(self):
|
||||
man_file = self.get_manifest_file()
|
||||
try:
|
||||
data = json.loads(fileutil.GetFileContents(manFile))
|
||||
data = json.loads(fileutil.read_file(man_file))
|
||||
except IOError as e:
|
||||
raise ExtensionError('Failed to load manifest file.')
|
||||
except ValueError as e:
|
||||
@@ -464,141 +464,141 @@ class ExtensionInstance(object):
|
||||
return HandlerManifest(data[0])
|
||||
|
||||
|
||||
def updateSettings(self):
|
||||
def update_settings(self):
|
||||
if self.settings is None:
|
||||
self.logger.verbose("Extension has no settings")
|
||||
return
|
||||
|
||||
handlerSettings = {
|
||||
settings = {
|
||||
'publicSettings': self.settings.publicSettings,
|
||||
'protectedSettings': self.settings.privateSettings,
|
||||
'protectedSettingsCertThumbprint': self.settings.certificateThumbprint
|
||||
}
|
||||
extSettings = {
|
||||
ext_settings = {
|
||||
"runtimeSettings":[{
|
||||
"handlerSettings": handlerSettings
|
||||
"handlerSettings": settings
|
||||
}]
|
||||
}
|
||||
fileutil.SetFileContents(self.getSettingsFile(), json.dumps(extSettings))
|
||||
fileutil.write_file(self.get_settings_file(), json.dumps(ext_settings))
|
||||
|
||||
latest = os.path.join(self.getConfigDir(), "latest")
|
||||
fileutil.SetFileContents(latest, self.settings.sequenceNumber)
|
||||
latest = os.path.join(self.get_conf_dir(), "latest")
|
||||
fileutil.write_file(latest, self.settings.sequenceNumber)
|
||||
|
||||
def createHandlerEnvironment(self):
|
||||
def create_handler_env(self):
|
||||
env = [{
|
||||
"name": self.getName(),
|
||||
"version" : self.getVersion(),
|
||||
"name": self.get_name(),
|
||||
"version" : self.get_version(),
|
||||
"handlerEnvironment" : {
|
||||
"logFolder" : self.getLogDir(),
|
||||
"configFolder" : self.getConfigDir(),
|
||||
"statusFolder" : self.getStatusDir(),
|
||||
"heartbeatFile" : self.getHeartbeatFile()
|
||||
"logFolder" : self.get_log_dir(),
|
||||
"configFolder" : self.get_conf_dir(),
|
||||
"statusFolder" : self.get_status_dir(),
|
||||
"heartbeatFile" : self.get_heartbeat_file()
|
||||
}
|
||||
}]
|
||||
fileutil.SetFileContents(self.getEnvironmentFile(),
|
||||
fileutil.write_file(self.get_env_file(),
|
||||
json.dumps(env))
|
||||
|
||||
def getTargetVersion(self):
|
||||
version = self.getVersion()
|
||||
updatePolicy = self.getUpgradePolicy()
|
||||
if updatePolicy is None or updatePolicy.lower() != 'auto':
|
||||
def get_target_version(self):
|
||||
version = self.get_version()
|
||||
update_policy = self.get_upgrade_policy()
|
||||
if update_policy is None or update_policy.lower() != 'auto':
|
||||
return version
|
||||
|
||||
|
||||
major = version.split('.')[0]
|
||||
if major is None:
|
||||
raise ExtensionError("Wrong version format: {0}".format(version))
|
||||
|
||||
packages = filter(lambda x : x.version.startswith(major + "."),
|
||||
self.packageList.versions)
|
||||
packages = filter(lambda x : x.version.startswith(major + "."),
|
||||
self.pkg_list.versions)
|
||||
packages = sorted(packages, key=lambda x: x.version, reverse=True)
|
||||
if len(packages) <= 0:
|
||||
raise ExtensionError("Can't find version: {0}.*".format(major))
|
||||
|
||||
return packages[0].version
|
||||
|
||||
def getPackageUris(self):
|
||||
version = self.getVersion()
|
||||
packages = self.packageList.versions
|
||||
def get_package_uris(self):
|
||||
version = self.get_version()
|
||||
packages = self.pkg_list.versions
|
||||
if packages is None:
|
||||
raise ExtensionError("Package uris is None.")
|
||||
|
||||
|
||||
for package in packages:
|
||||
if package.version == version:
|
||||
return package.uris
|
||||
|
||||
raise ExtensionError("Can't get package uris for {0}.".format(version))
|
||||
|
||||
def getCurrOperation(self):
|
||||
return self.currOperation
|
||||
|
||||
def getName(self):
|
||||
|
||||
def get_curr_op(self):
|
||||
return self.curr_op
|
||||
|
||||
def get_name(self):
|
||||
return self.extension.name
|
||||
|
||||
def getVersion(self):
|
||||
def get_version(self):
|
||||
return self.extension.properties.version
|
||||
|
||||
def getState(self):
|
||||
def get_state(self):
|
||||
return self.extension.properties.state
|
||||
|
||||
def getSeqNo(self):
|
||||
def get_seq_no(self):
|
||||
return self.settings.sequenceNumber
|
||||
|
||||
def getUpgradePolicy(self):
|
||||
|
||||
def get_upgrade_policy(self):
|
||||
return self.extension.properties.upgradePolicy
|
||||
|
||||
def getFullName(self):
|
||||
return "{0}-{1}".format(self.getName(), self.currVersion)
|
||||
|
||||
def getBaseDir(self):
|
||||
return os.path.join(OSUtil.GetLibDir(), self.getFullName())
|
||||
def get_full_name(self):
|
||||
return "{0}-{1}".format(self.get_name(), self.curr_version)
|
||||
|
||||
def getStatusDir(self):
|
||||
return os.path.join(self.getBaseDir(), "status")
|
||||
def get_base_dir(self):
|
||||
return os.path.join(OSUTIL.get_lib_dir(), self.get_full_name())
|
||||
|
||||
def getStatusFile(self):
|
||||
return os.path.join(self.getStatusDir(),
|
||||
def get_status_dir(self):
|
||||
return os.path.join(self.get_base_dir(), "status")
|
||||
|
||||
def get_status_file(self):
|
||||
return os.path.join(self.get_status_dir(),
|
||||
"{0}.status".format(self.settings.sequenceNumber))
|
||||
|
||||
def getConfigDir(self):
|
||||
return os.path.join(self.getBaseDir(), 'config')
|
||||
def get_conf_dir(self):
|
||||
return os.path.join(self.get_base_dir(), 'config')
|
||||
|
||||
def getSettingsFile(self):
|
||||
return os.path.join(self.getConfigDir(),
|
||||
def get_settings_file(self):
|
||||
return os.path.join(self.get_conf_dir(),
|
||||
"{0}.settings".format(self.settings.sequenceNumber))
|
||||
|
||||
def getHandlerStateFile(self):
|
||||
return os.path.join(self.getConfigDir(), 'HandlerState')
|
||||
def get_handler_state_file(self):
|
||||
return os.path.join(self.get_conf_dir(), 'HandlerState')
|
||||
|
||||
def getHeartbeatFile(self):
|
||||
return os.path.join(self.getBaseDir(), 'heartbeat.log')
|
||||
def get_heartbeat_file(self):
|
||||
return os.path.join(self.get_base_dir(), 'heartbeat.log')
|
||||
|
||||
def getManifestFile(self):
|
||||
return os.path.join(self.getBaseDir(), 'HandlerManifest.json')
|
||||
def get_manifest_file(self):
|
||||
return os.path.join(self.get_base_dir(), 'HandlerManifest.json')
|
||||
|
||||
def getEnvironmentFile(self):
|
||||
return os.path.join(self.getBaseDir(), 'HandlerEnvironment.json')
|
||||
def get_env_file(self):
|
||||
return os.path.join(self.get_base_dir(), 'HandlerEnvironment.json')
|
||||
|
||||
def getLogDir(self):
|
||||
return os.path.join(OSUtil.GetExtLogDir(), self.getName(),
|
||||
self.currVersion)
|
||||
def get_log_dir(self):
|
||||
return os.path.join(OSUTIL.get_ext_log_dir(), self.get_name(),
|
||||
self.curr_version)
|
||||
|
||||
class HandlerEnvironment(object):
|
||||
def __init__(self, data):
|
||||
self.data = data
|
||||
|
||||
def getVersion(self):
|
||||
|
||||
def get_version(self):
|
||||
return self.data["version"]
|
||||
|
||||
def getLogDir(self):
|
||||
def get_log_dir(self):
|
||||
return self.data["handlerEnvironment"]["logFolder"]
|
||||
|
||||
def getConfigDir(self):
|
||||
def get_conf_dir(self):
|
||||
return self.data["handlerEnvironment"]["configFolder"]
|
||||
|
||||
def getStatusDir(self):
|
||||
def get_status_dir(self):
|
||||
return self.data["handlerEnvironment"]["statusFolder"]
|
||||
|
||||
def getHeartbeatFile(self):
|
||||
def get_heartbeat_file(self):
|
||||
return self.data["handlerEnvironment"]["heartbeatFile"]
|
||||
|
||||
class HandlerManifest(object):
|
||||
@@ -607,39 +607,39 @@ class HandlerManifest(object):
|
||||
raise ExtensionError('Malformed manifest file.')
|
||||
self.data = data
|
||||
|
||||
def getName(self):
|
||||
def get_name(self):
|
||||
return self.data["name"]
|
||||
|
||||
def getVersion(self):
|
||||
def get_version(self):
|
||||
return self.data["version"]
|
||||
|
||||
def getInstallCommand(self):
|
||||
def get_install_command(self):
|
||||
return self.data['handlerManifest']["installCommand"]
|
||||
|
||||
def getUninstallCommand(self):
|
||||
def get_uninstall_command(self):
|
||||
return self.data['handlerManifest']["uninstallCommand"]
|
||||
|
||||
def getUpdateCommand(self):
|
||||
def get_update_command(self):
|
||||
return self.data['handlerManifest']["updateCommand"]
|
||||
|
||||
def getEnableCommand(self):
|
||||
def get_enable_command(self):
|
||||
return self.data['handlerManifest']["enableCommand"]
|
||||
|
||||
def getDisableCommand(self):
|
||||
def get_disable_command(self):
|
||||
return self.data['handlerManifest']["disableCommand"]
|
||||
|
||||
def isRebootAfterInstall(self):
|
||||
def is_reboot_after_install(self):
|
||||
#TODO handle reboot after install
|
||||
if "rebootAfterInstall" not in self.data['handlerManifest']:
|
||||
return False
|
||||
return self.data['handlerManifest']["rebootAfterInstall"]
|
||||
|
||||
def isReportHeartbeat(self):
|
||||
def is_report_heartbeat(self):
|
||||
if "reportHeartbeat" not in self.data['handlerManifest']:
|
||||
return False
|
||||
return self.data['handlerManifest']["reportHeartbeat"]
|
||||
|
||||
def isUpdateWithInstall(self):
|
||||
def is_update_with_install(self):
|
||||
if "updateMode" not in self.data['handlerManifest']:
|
||||
return False
|
||||
if "updateMode" in self.data:
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
from init import InitHandler
|
||||
from run import RunHandler
|
||||
from run import MainHandler
|
||||
from scvmm import ScvmmHandler
|
||||
from dhcp import DhcpHandler
|
||||
from env import EnvHandler
|
||||
@@ -25,16 +25,16 @@ from provision import ProvisionHandler
|
||||
from resourceDisk import ResourceDiskHandler
|
||||
from extension import ExtensionsHandler
|
||||
from deprovision import DeprovisionHandler
|
||||
|
||||
|
||||
class DefaultHandlerFactory(object):
|
||||
def __init__(self):
|
||||
self.initHandler = InitHandler()
|
||||
self.runHandler = RunHandler(self)
|
||||
self.scvmmHandler = ScvmmHandler()
|
||||
self.dhcpHandler = DhcpHandler()
|
||||
self.envHandler = EnvHandler(self)
|
||||
self.provisionHandler = ProvisionHandler()
|
||||
self.resourceDiskHandler = ResourceDiskHandler()
|
||||
self.extensionHandler = ExtensionsHandler()
|
||||
self.deprovisionHandler = DeprovisionHandler()
|
||||
self.init_handler = InitHandler()
|
||||
self.main_handler = MainHandler(self)
|
||||
self.scvmm_handler = ScvmmHandler()
|
||||
self.dhcp_handler = DhcpHandler()
|
||||
self.env_handler = EnvHandler(self)
|
||||
self.provision_handler = ProvisionHandler()
|
||||
self.resource_disk_handler = ResourceDiskHandler()
|
||||
self.extension_handler = ExtensionsHandler()
|
||||
self.deprovision_handler = DeprovisionHandler()
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@
|
||||
import os
|
||||
import azurelinuxagent.conf as conf
|
||||
import azurelinuxagent.logger as logger
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
|
||||
|
||||
@@ -28,22 +28,22 @@ class InitHandler(object):
|
||||
def init(self, verbose):
|
||||
#Init stdout log
|
||||
level = logger.LogLevel.VERBOSE if verbose else logger.LogLevel.INFO
|
||||
logger.AddLoggerAppender(logger.AppenderType.STDOUT, level)
|
||||
logger.add_logger_appender(logger.AppenderType.STDOUT, level)
|
||||
|
||||
#Init config
|
||||
configPath = OSUtil.GetConfigurationPath()
|
||||
conf.LoadConfiguration(configPath)
|
||||
|
||||
conf_file_path = OSUTIL.get_conf_file_path()
|
||||
conf.load_conf(conf_file_path)
|
||||
|
||||
#Init log
|
||||
verbose = verbose or conf.GetSwitch("Logs.Verbose", False)
|
||||
verbose = verbose or conf.get_switch("Logs.Verbose", False)
|
||||
level = logger.LogLevel.VERBOSE if verbose else logger.LogLevel.INFO
|
||||
logger.AddLoggerAppender(logger.AppenderType.FILE, level,
|
||||
logger.add_logger_appender(logger.AppenderType.FILE, level,
|
||||
path="/var/log/waagent.log")
|
||||
logger.AddLoggerAppender(logger.AppenderType.CONSOLE, level,
|
||||
logger.add_logger_appender(logger.AppenderType.CONSOLE, level,
|
||||
path="/dev/console")
|
||||
|
||||
|
||||
#Create lib dir
|
||||
fileutil.CreateDir(OSUtil.GetLibDir(), mode=0700)
|
||||
os.chdir(OSUtil.GetLibDir())
|
||||
fileutil.mkdir(OSUTIL.get_lib_dir(), mode=0700)
|
||||
os.chdir(OSUTIL.get_lib_dir())
|
||||
|
||||
|
||||
|
||||
@@ -17,12 +17,12 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
def GetOSUtil():
|
||||
def get_osutil():
|
||||
from azurelinuxagent.distro.default.osutil import DefaultOSUtil
|
||||
return DefaultOSUtil()
|
||||
|
||||
def GetHandlers():
|
||||
def get_handlers():
|
||||
from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory
|
||||
return DefaultHandlerFactory()
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -18,23 +18,23 @@
|
||||
|
||||
import os
|
||||
import re
|
||||
import pwd
|
||||
import shutil
|
||||
import socket
|
||||
import array
|
||||
import struct
|
||||
import fcntl
|
||||
import time
|
||||
import pwd
|
||||
import fcntl
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import azurelinuxagent.utils.textutil as textutil
|
||||
|
||||
RulesFiles = [ "/lib/udev/rules.d/75-persistent-net-generator.rules",
|
||||
"/etc/udev/rules.d/70-persistent-net.rules" ]
|
||||
__RULES_FILES__ = [ "/lib/udev/rules.d/75-persistent-net-generator.rules",
|
||||
"/etc/udev/rules.d/70-persistent-net.rules" ]
|
||||
|
||||
"""
|
||||
Define distro specific behavior. OSUtil class defines default behavior
|
||||
Define distro specific behavior. OSUtil class defines default behavior
|
||||
for all distros. Each concrete distro classes could overwrite default behavior
|
||||
if needed.
|
||||
"""
|
||||
@@ -45,70 +45,70 @@ class OSUtilError(Exception):
|
||||
class DefaultOSUtil(object):
|
||||
|
||||
def __init__(self):
|
||||
self.libDir = "/var/lib/waagent"
|
||||
self.extLogDir = "/var/log/azure"
|
||||
self.dvdMountPoint = "/mnt/cdrom/secure"
|
||||
self.ovfenvPathOnDvd = "/mnt/cdrom/secure/ovf-env.xml"
|
||||
self.agentPidPath = "/var/run/waagent.pid"
|
||||
self.passwdPath = "/etc/shadow"
|
||||
self.lib_dir = "/var/lib/waagent"
|
||||
self.ext_log_dir = "/var/log/azure"
|
||||
self.dvd_mount_point = "/mnt/cdrom/secure"
|
||||
self.ovf_env_file_path = "/mnt/cdrom/secure/ovf-env.xml"
|
||||
self.agent_pid_file_path = "/var/run/waagent.pid"
|
||||
self.passwd_file_path = "/etc/shadow"
|
||||
self.home = '/home'
|
||||
self.sshdConfigPath = '/etc/ssh/sshd_config'
|
||||
self.opensslCmd = '/usr/bin/openssl'
|
||||
self.configPath = '/etc/waagent.conf'
|
||||
self.sshd_conf_file_path = '/etc/ssh/sshd_config'
|
||||
self.openssl_cmd = '/usr/bin/openssl'
|
||||
self.conf_file_path = '/etc/waagent.conf'
|
||||
self.selinux=None
|
||||
|
||||
def GetLibDir(self):
|
||||
return self.libDir
|
||||
def get_lib_dir(self):
|
||||
return self.lib_dir
|
||||
|
||||
def GetExtLogDir(self):
|
||||
return self.extLogDir
|
||||
def get_ext_log_dir(self):
|
||||
return self.ext_log_dir
|
||||
|
||||
def GetDvdMountPoint(self):
|
||||
return self.dvdMountPoint
|
||||
def get_dvd_mount_point(self):
|
||||
return self.dvd_mount_point
|
||||
|
||||
def GetConfigurationPath(self):
|
||||
return self.configPath
|
||||
def get_conf_file_path(self):
|
||||
return self.conf_file_path
|
||||
|
||||
def GetOvfEnvPathOnDvd(self):
|
||||
return self.ovfenvPathOnDvd
|
||||
def get_ovf_env_file_path_on_dvd(self):
|
||||
return self.ovf_env_file_path
|
||||
|
||||
def GetAgentPidPath(self):
|
||||
return self.agentPidPath
|
||||
def get_agent_pid_file_path(self):
|
||||
return self.agent_pid_file_path
|
||||
|
||||
def GetOpensslCmd(self):
|
||||
return self.opensslCmd
|
||||
def get_openssl_cmd(self):
|
||||
return self.openssl_cmd
|
||||
|
||||
def UpdateUserAccount(self, userName, password, expiration=None):
|
||||
def set_user_account(self, username, password, expiration=None):
|
||||
"""
|
||||
Update password and ssh key for user account.
|
||||
New account will be created if not exists.
|
||||
"""
|
||||
if userName is None:
|
||||
if username is None:
|
||||
raise OSUtilError("User name is empty")
|
||||
|
||||
if self.IsSysUser(userName):
|
||||
if self.is_sys_user(username):
|
||||
raise OSUtilError(("User {0} is a system user. "
|
||||
"Will not set passwd.").format(userName))
|
||||
"Will not set passwd.").format(username))
|
||||
|
||||
userentry = self.GetUserEntry(userName)
|
||||
userentry = self.get_userentry(username)
|
||||
if userentry is None:
|
||||
self.CreateUserAccount(userName, expiration)
|
||||
|
||||
self.ConfigSudoer(userName, password is None)
|
||||
self.useradd(username, expiration)
|
||||
|
||||
def GetUserEntry(self, userName):
|
||||
self.conf_sudoer(username, password is None)
|
||||
|
||||
def get_userentry(self, username):
|
||||
try:
|
||||
return pwd.getpwnam(userName)
|
||||
return pwd.getpwnam(username)
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
def IsSysUser(self, userName):
|
||||
userentry = self.GetUserEntry(userName)
|
||||
def is_sys_user(self, username):
|
||||
userentry = self.get_userentry(username)
|
||||
uidmin = None
|
||||
try:
|
||||
uidminDef = fileutil.GetLineStartingWith("UID_MIN", "/etc/login.defs")
|
||||
if uidminDef is not None:
|
||||
uidmin = int(uidminDef.split()[1])
|
||||
uidmin_def = fileutil.get_line_startingwith("UID_MIN", "/etc/login.defs")
|
||||
if uidmin_def is not None:
|
||||
uidmin = int(uidmin_def.split()[1])
|
||||
except IOError as e:
|
||||
pass
|
||||
if uidmin == None:
|
||||
@@ -117,328 +117,328 @@ class DefaultOSUtil(object):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def CreateUserAccount(self, userName, expiration=None):
|
||||
|
||||
def useradd(self, username, expiration=None):
|
||||
if expiration is not None:
|
||||
cmd = "useradd -m {0} -e {1}".format(userName, expiration)
|
||||
cmd = "useradd -m {0} -e {1}".format(username, expiration)
|
||||
else:
|
||||
cmd = "useradd -m {0}".format(userName)
|
||||
retcode, out = shellutil.RunGetOutput(cmd)
|
||||
cmd = "useradd -m {0}".format(username)
|
||||
retcode, out = shellutil.run_get_output(cmd)
|
||||
if retcode != 0:
|
||||
raise OSUtilError(("Failed to create user account:{0}, "
|
||||
"retcode:{1}, "
|
||||
"output:{2}").format(userName, retcode, out))
|
||||
"output:{2}").format(username, retcode, out))
|
||||
|
||||
def ChangePassword(self, userName, password, useSalt=True, saltType=6,
|
||||
saltLength=10):
|
||||
passwdHash = textutil.GetPasswordHash(password, useSalt, saltType,
|
||||
saltLength)
|
||||
def chpasswd(self, username, password, use_salt=True, salt_type=6,
|
||||
salt_len=10):
|
||||
passwd_hash = textutil.gen_password_hash(password, use_salt, salt_type,
|
||||
salt_len)
|
||||
try:
|
||||
passwdContent = fileutil.GetFileContents(self.passwdPath)
|
||||
passwd = passwdContent.split("\n")
|
||||
newPasswd = filter(lambda x : not x.startswith(userName), passwd)
|
||||
newPasswd.append("{0}:{1}:14600::::::".format(userName, passwdHash))
|
||||
fileutil.SetFileContents(self.passwdPath, "\n".join(newPasswd))
|
||||
passwd_content = fileutil.read_file(self.passwd_file_path)
|
||||
passwd = passwd_content.split("\n")
|
||||
new_passwd = filter(lambda x : not x.startswith(username), passwd)
|
||||
new_passwd.append("{0}:{1}:14600::::::".format(username, passwd_hash))
|
||||
fileutil.write_file(self.passwd_file_path, "\n".join(new_passwd))
|
||||
except IOError as e:
|
||||
raise OSUtilError(("Failed to set password for {0}: {1}"
|
||||
"").format(userName, e))
|
||||
"").format(username, e))
|
||||
|
||||
def ConfigSudoer(self, userName, nopasswd):
|
||||
def conf_sudoer(self, username, nopasswd):
|
||||
# for older distros create sudoers.d
|
||||
if not os.path.isdir('/etc/sudoers.d/'):
|
||||
# create the /etc/sudoers.d/ directory
|
||||
os.mkdir('/etc/sudoers.d/')
|
||||
# add the include of sudoers.d to the /etc/sudoers
|
||||
sudoers = '\n' + '#includedir /etc/sudoers.d/\n'
|
||||
fileutil.AppendFileContents('/etc/sudoers', sudoers)
|
||||
fileutil.append_file('/etc/sudoers', sudoers)
|
||||
sudoer = None
|
||||
if nopasswd:
|
||||
sudoer = "{0} ALL = (ALL) NOPASSWD\n".format(userName)
|
||||
sudoer = "{0} ALL = (ALL) NOPASSWD\n".format(username)
|
||||
else:
|
||||
sudoer = "{0} ALL = (ALL) ALL\n".format(userName)
|
||||
fileutil.AppendFileContents('/etc/sudoers.d/waagent', sudoer)
|
||||
fileutil.ChangeMod('/etc/sudoers.d/waagent', 0440)
|
||||
sudoer = "{0} ALL = (ALL) ALL\n".format(username)
|
||||
fileutil.append_file('/etc/sudoers.d/waagent', sudoer)
|
||||
fileutil.chmod('/etc/sudoers.d/waagent', 0440)
|
||||
|
||||
def DeleteRootPassword(self):
|
||||
def del_root_password(self):
|
||||
try:
|
||||
passwdContent = fileutil.GetFileContents(self.passwdPath)
|
||||
passwd = passwdContent.split('\n')
|
||||
newPasswd = filter(lambda x : not x.startswith("root:"), passwd)
|
||||
newPasswd.insert(0, "root:*LOCK*:14600::::::")
|
||||
fileutil.SetFileContents(self.passwdPath, "\n".join(newPasswd))
|
||||
passwd_content = fileutil.read_file(self.passwd_file_path)
|
||||
passwd = passwd_content.split('\n')
|
||||
new_passwd = filter(lambda x : not x.startswith("root:"), passwd)
|
||||
new_passwd.insert(0, "root:*LOCK*:14600::::::")
|
||||
fileutil.write_file(self.passwd_file_path, "\n".join(new_passwd))
|
||||
except IOError as e:
|
||||
raise OSUtilError("Failed to delete root password:{0}".format(e))
|
||||
|
||||
def GetHome(self):
|
||||
def get_home(self):
|
||||
return self.home
|
||||
|
||||
def GetPubKeyFromPrv(self, fileName):
|
||||
cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.opensslCmd,
|
||||
fileName)
|
||||
pub = shellutil.RunGetOutput(cmd)[1]
|
||||
|
||||
def get_pubkey_from_prv(self, file_name):
|
||||
cmd = "{0} rsa -in {1} -pubout 2>/dev/null".format(self.openssl_cmd,
|
||||
file_name)
|
||||
pub = shellutil.run_get_output(cmd)[1]
|
||||
return pub
|
||||
|
||||
def GetPubKeyFromCrt(self, fileName):
|
||||
cmd = "{0} x509 -in {1} -pubkey -noout".format(self.opensslCmd,
|
||||
fileName)
|
||||
pub = shellutil.RunGetOutput(cmd)[1]
|
||||
def get_pubkey_from_crt(self, file_name):
|
||||
cmd = "{0} x509 -in {1} -pubkey -noout".format(self.openssl_cmd,
|
||||
file_name)
|
||||
pub = shellutil.run_get_output(cmd)[1]
|
||||
return pub
|
||||
|
||||
def _NormPath(self, filepath):
|
||||
home = self.GetHome()
|
||||
def _norm_path(self, filepath):
|
||||
home = self.get_home()
|
||||
# Expand HOME variable if present in path
|
||||
path = os.path.normpath(filepath.replace("$HOME", home))
|
||||
return path
|
||||
|
||||
def GetThumbprintFromCrt(self, fileName):
|
||||
cmd="{0} x509 -in {1} -fingerprint -noout".format(self.opensslCmd,
|
||||
fileName)
|
||||
thumbprint = shellutil.RunGetOutput(cmd)[1]
|
||||
|
||||
def get_thumbprint_from_crt(self, file_name):
|
||||
cmd="{0} x509 -in {1} -fingerprint -noout".format(self.openssl_cmd,
|
||||
file_name)
|
||||
thumbprint = shellutil.run_get_output(cmd)[1]
|
||||
thumbprint = thumbprint.rstrip().split('=')[1].replace(':', '').upper()
|
||||
return thumbprint
|
||||
|
||||
def DeploySshKeyPair(self, userName, thumbprint, path):
|
||||
|
||||
def deploy_ssh_keypair(self, username, thumbprint, path):
|
||||
"""
|
||||
Deploy id_rsa and id_rsa.pub
|
||||
"""
|
||||
path = self._NormPath(path)
|
||||
dirPath = os.path.dirname(path)
|
||||
fileutil.CreateDir(dirPath, mode=0700, owner=userName)
|
||||
libDir = self.GetLibDir()
|
||||
prvPath = os.path.join(libDir, thumbprint + '.prv')
|
||||
if not os.path.isfile(prvPath):
|
||||
logger.Error("Failed to deploy key pair, thumbprint: {0}",
|
||||
path = self._norm_path(path)
|
||||
dir_path = os.path.dirname(path)
|
||||
fileutil.mkdir(dir_path, mode=0700, owner=username)
|
||||
lib_dir = self.get_lib_dir()
|
||||
prv_path = os.path.join(lib_dir, thumbprint + '.prv')
|
||||
if not os.path.isfile(prv_path):
|
||||
logger.error("Failed to deploy key pair, thumbprint: {0}",
|
||||
thumbprint)
|
||||
return
|
||||
shutil.copyfile(prvPath, path)
|
||||
pubPath = path + '.pub'
|
||||
pub = self.GetPubKeyFromPrv(prvPath)
|
||||
fileutil.SetFileContents(pubPath, pub)
|
||||
self.SetSelinuxContext(pubPath, 'unconfined_u:object_r:ssh_home_t:s0')
|
||||
self.SetSelinuxContext(path, 'unconfined_u:object_r:ssh_home_t:s0')
|
||||
shutil.copyfile(prv_path, path)
|
||||
pub_path = path + '.pub'
|
||||
pub = self.get_pubkey_from_prv(prv_path)
|
||||
fileutil.write_file(pub_path, pub)
|
||||
self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0')
|
||||
self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0')
|
||||
os.chmod(path, 0644)
|
||||
os.chmod(pubPath, 0600)
|
||||
os.chmod(pub_path, 0600)
|
||||
|
||||
def OpenSslToOpenSsh(self, inputFile, outputFile):
|
||||
shellutil.Run("ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(inputFile,
|
||||
outputFile))
|
||||
def openssl_to_openssh(self, input_file, output_file):
|
||||
shellutil.run("ssh-keygen -i -m PKCS8 -f {0} >> {1}".format(input_file,
|
||||
output_file))
|
||||
|
||||
def DeploySshPublicKey(self, userName, thumbprint, path):
|
||||
def deploy_ssh_pubkey(self, username, thumbprint, path):
|
||||
"""
|
||||
Deploy authorized_key
|
||||
"""
|
||||
path = self._NormPath(path)
|
||||
dirPath = os.path.dirname(path)
|
||||
fileutil.CreateDir(dirPath, mode=0700, owner=userName)
|
||||
libDir = self.GetLibDir()
|
||||
crtPath = os.path.join(libDir, thumbprint + '.crt')
|
||||
if not os.path.isfile(crtPath):
|
||||
logger.Error("Failed to deploy public key, thumbprint: {0}",
|
||||
path = self._norm_path(path)
|
||||
dir_path = os.path.dirname(path)
|
||||
fileutil.mkdir(dir_path, mode=0700, owner=username)
|
||||
lib_dir = self.get_lib_dir()
|
||||
crt_path = os.path.join(lib_dir, thumbprint + '.crt')
|
||||
if not os.path.isfile(crt_path):
|
||||
logger.error("Failed to deploy public key, thumbprint: {0}",
|
||||
thumbprint)
|
||||
return
|
||||
pubPath = os.path.join(libDir, thumbprint + '.pub')
|
||||
pub = self.GetPubKeyFromCrt(crtPath)
|
||||
fileutil.SetFileContents(pubPath, pub)
|
||||
self.SetSelinuxContext(pubPath, 'unconfined_u:object_r:ssh_home_t:s0')
|
||||
self.OpenSslToOpenSsh(pubPath, path)
|
||||
self.SetSelinuxContext(path, 'unconfined_u:object_r:ssh_home_t:s0')
|
||||
fileutil.ChangeOwner(path, userName)
|
||||
fileutil.ChangeMod(path, 0644)
|
||||
fileutil.ChangeMod(pubPath, 0600)
|
||||
|
||||
def IsSelinuxSystem(self):
|
||||
pub_path = os.path.join(lib_dir, thumbprint + '.pub')
|
||||
pub = self.get_pubkey_from_crt(crt_path)
|
||||
fileutil.write_file(pub_path, pub)
|
||||
self.set_selinux_context(pub_path, 'unconfined_u:object_r:ssh_home_t:s0')
|
||||
self.openssl_to_openssh(pub_path, path)
|
||||
self.set_selinux_context(path, 'unconfined_u:object_r:ssh_home_t:s0')
|
||||
fileutil.chowner(path, username)
|
||||
fileutil.chmod(path, 0644)
|
||||
fileutil.chmod(pub_path, 0600)
|
||||
|
||||
def is_selinux_system(self):
|
||||
"""
|
||||
Checks and sets self.selinux = True if SELinux is available on system.
|
||||
"""
|
||||
if self.selinux == None:
|
||||
if shellutil.Run("which getenforce", chk_err=False) == 0:
|
||||
if shellutil.run("which getenforce", chk_err=False) == 0:
|
||||
self.selinux = True
|
||||
else:
|
||||
self.selinux = False
|
||||
return self.selinux
|
||||
|
||||
def IsSelinuxRunning(self):
|
||||
|
||||
def is_selinux_enforcing(self):
|
||||
"""
|
||||
Calls shell command 'getenforce' and returns True if 'Enforcing'.
|
||||
"""
|
||||
if self.IsSelinuxSystem():
|
||||
output = shellutil.RunGetOutput("getenforce")[1]
|
||||
if self.is_selinux_system():
|
||||
output = shellutil.run_get_output("getenforce")[1]
|
||||
return output.startswith("Enforcing")
|
||||
else:
|
||||
return False
|
||||
|
||||
def SetSelinuxEnforce(self, state):
|
||||
|
||||
def set_selinux_enforce(self, state):
|
||||
"""
|
||||
Calls shell command 'setenforce' with 'state'
|
||||
Calls shell command 'setenforce' with 'state'
|
||||
and returns resulting exit code.
|
||||
"""
|
||||
if self.IsSelinuxSystem():
|
||||
if self.is_selinux_system():
|
||||
if state: s = '1'
|
||||
else: s='0'
|
||||
return shellutil.Run("setenforce "+s)
|
||||
return shellutil.run("setenforce "+s)
|
||||
|
||||
def SetSelinuxContext(self, path, cn):
|
||||
def set_selinux_context(self, path, con):
|
||||
"""
|
||||
Calls shell 'chcon' with 'path' and 'cn' context.
|
||||
Calls shell 'chcon' with 'path' and 'con' context.
|
||||
Returns exit result.
|
||||
"""
|
||||
if self.IsSelinuxSystem():
|
||||
return shellutil.Run('chcon ' + cn + ' ' + path)
|
||||
|
||||
def GetSshdConfigPath(self):
|
||||
return self.sshdConfigPath
|
||||
if self.is_selinux_system():
|
||||
return shellutil.run('chcon ' + con + ' ' + path)
|
||||
|
||||
def SetSshClientAliveInterval(self):
|
||||
configPath = self.GetSshdConfigPath()
|
||||
config = fileutil.GetFileContents(configPath).split("\n")
|
||||
textutil.SetSshConfig(config, "ClientAliveInterval", "180")
|
||||
fileutil.ReplaceFileContentsAtomic(configPath, '\n'.join(config))
|
||||
logger.Info("Configured SSH client probing to keep connections alive.")
|
||||
|
||||
def ConfigSshd(self, disablePassword):
|
||||
option = "no" if disablePassword else "yes"
|
||||
configPath = self.GetSshdConfigPath()
|
||||
config = fileutil.GetFileContents(configPath).split("\n")
|
||||
textutil.SetSshConfig(config, "PasswordAuthentication", option)
|
||||
textutil.SetSshConfig(config, "ChallengeResponseAuthentication", option)
|
||||
fileutil.ReplaceFileContentsAtomic(configPath, "\n".join(config))
|
||||
logger.Info("Disabled SSH password-based authentication methods.")
|
||||
def get_sshd_conf_file_path(self):
|
||||
return self.sshd_conf_file_path
|
||||
|
||||
def set_ssh_client_alive_interval(self):
|
||||
conf_file_path = self.get_sshd_conf_file_path()
|
||||
conf = fileutil.read_file(conf_file_path).split("\n")
|
||||
textutil.set_ssh_config(conf, "ClientAliveInterval", "180")
|
||||
fileutil.replace_file(conf_file_path, '\n'.join(conf))
|
||||
logger.info("Configured SSH client probing to keep connections alive.")
|
||||
|
||||
def conf_sshd(self, disable_password):
|
||||
option = "no" if disable_password else "yes"
|
||||
conf_file_path = self.get_sshd_conf_file_path()
|
||||
conf = fileutil.read_file(conf_file_path).split("\n")
|
||||
textutil.set_ssh_config(conf, "PasswordAuthentication", option)
|
||||
textutil.set_ssh_config(conf, "ChallengeResponseAuthentication", option)
|
||||
fileutil.replace_file(conf_file_path, "\n".join(conf))
|
||||
logger.info("Disabled SSH password-based authentication methods.")
|
||||
|
||||
|
||||
def GetDvdDevice(self, devDir='/dev'):
|
||||
def get_dvd_device(self, dev_dir='/dev'):
|
||||
patten=r'(sr[0-9]|hd[c-z]|cdrom[0-9])'
|
||||
for dvd in [re.match(patten, dev) for dev in os.listdir(devDir)]:
|
||||
for dvd in [re.match(patten, dev) for dev in os.listdir(dev_dir)]:
|
||||
if dvd is not None:
|
||||
return "/dev/{0}".format(dvd.group(0))
|
||||
raise OSUtilError("Failed to get dvd device")
|
||||
|
||||
def MountDvd(self, maxRetry=6, chk_err=True):
|
||||
dvd = self.GetDvdDevice()
|
||||
mountPoint = self.GetDvdMountPoint()
|
||||
mountlist = shellutil.RunGetOutput("mount")[1]
|
||||
existing = self.GetMountPoint(mountlist, dvd)
|
||||
def mount_dvd(self, max_retry=6, chk_err=True):
|
||||
dvd = self.get_dvd_device()
|
||||
mount_point = self.get_dvd_mount_point()
|
||||
mountlist = shellutil.run_get_output("mount")[1]
|
||||
existing = self.get_mount_point(mountlist, dvd)
|
||||
if existing is not None: #Already mounted
|
||||
logger.Info("{0} is already mounted at {1}", dvd, existing)
|
||||
logger.info("{0} is already mounted at {1}", dvd, existing)
|
||||
return
|
||||
if not os.path.isdir(mountPoint):
|
||||
os.makedirs(mountPoint)
|
||||
|
||||
for retry in range(0, maxRetry):
|
||||
retcode = self.Mount(dvd, mountPoint, option="-o ro -t iso9660,udf",
|
||||
if not os.path.isdir(mount_point):
|
||||
os.makedirs(mount_point)
|
||||
|
||||
for retry in range(0, max_retry):
|
||||
retcode = self.mount(dvd, mount_point, option="-o ro -t iso9660,udf",
|
||||
chk_err=chk_err)
|
||||
if retcode == 0:
|
||||
logger.Info("Successfully mounted dvd")
|
||||
logger.info("Successfully mounted dvd")
|
||||
return
|
||||
if retry < maxRetry - 1:
|
||||
logger.Warn("Mount dvd failed: retry={0}, ret={1}", retry,
|
||||
if retry < max_retry - 1:
|
||||
logger.warn("Mount dvd failed: retry={0}, ret={1}", retry,
|
||||
retcode)
|
||||
time.sleep(5)
|
||||
if chk_err:
|
||||
raise OSUtilError("Failed to mount dvd.")
|
||||
|
||||
def UmountDvd(self, chk_err=True):
|
||||
mountPoint = self.GetDvdMountPoint()
|
||||
retcode = self.Umount(mountPoint, chk_err=chk_err)
|
||||
def umount_dvd(self, chk_err=True):
|
||||
mount_point = self.get_dvd_mount_point()
|
||||
retcode = self.umount(mount_point, chk_err=chk_err)
|
||||
if chk_err and retcode != 0:
|
||||
raise OSUtilError("Failed to umount dvd.")
|
||||
|
||||
def LoadAtapiixModule(self):
|
||||
if self.IsAtaPiixModuleLoaded():
|
||||
def load_atappix_mod(self):
|
||||
if self.is_atapiix_mod_loaded():
|
||||
return
|
||||
ret, kernVersion = shellutil.RunGetOutput("uname -r")
|
||||
ret, kern_version = shellutil.run_get_output("uname -r")
|
||||
if ret != 0:
|
||||
raise Exception("Failed to call uname -r")
|
||||
modulePath = os.path.join('/lib/modules',
|
||||
kernVersion.strip('\n'),
|
||||
'kernel/drivers/ata/ata_piix.ko')
|
||||
if not os.path.isfile(modulePath):
|
||||
raise Exception("Can't find module file:{0}".format(modulePath))
|
||||
mod_path = os.path.join('/lib/modules',
|
||||
kern_version.strip('\n'),
|
||||
'kernel/drivers/ata/ata_piix.ko')
|
||||
if not os.path.isfile(mod_path):
|
||||
raise Exception("Can't find module file:{0}".format(mod_path))
|
||||
|
||||
ret, output = shellutil.RunGetOutput("insmod " + modulePath)
|
||||
ret, output = shellutil.run_get_output("insmod " + mod_path)
|
||||
if ret != 0:
|
||||
raise Exception("Error calling insmod for ATAPI CD-ROM driver")
|
||||
if not self.IsAtaPiixModuleLoaded(maxRetry=3):
|
||||
raise Exception("Failed to load ATAPI CD-ROM driver")
|
||||
if not self.is_atapiix_mod_loaded(max_retry=3):
|
||||
raise Exception("Failed to load ATAPI CD-ROM driver")
|
||||
|
||||
def IsAtaPiixModuleLoaded(self, maxRetry=1):
|
||||
for retry in range(0, maxRetry):
|
||||
ret = shellutil.Run("lsmod | grep ata_piix", chk_err=False)
|
||||
def is_atapiix_mod_loaded(self, max_retry=1):
|
||||
for retry in range(0, max_retry):
|
||||
ret = shellutil.run("lsmod | grep ata_piix", chk_err=False)
|
||||
if ret == 0:
|
||||
logger.Info("Module driver for ATAPI CD-ROM is already present.")
|
||||
logger.info("Module driver for ATAPI CD-ROM is already present.")
|
||||
return True
|
||||
if retry < maxRetry - 1:
|
||||
if retry < max_retry - 1:
|
||||
time.sleep(1)
|
||||
return False
|
||||
|
||||
def Mount(self, dvd, mountPoint, option="", chk_err=True):
|
||||
cmd = "mount {0} {1} {2}".format(dvd, option, mountPoint)
|
||||
return shellutil.RunGetOutput(cmd, chk_err)[0]
|
||||
|
||||
def Umount(self, mountPoint, chk_err=True):
|
||||
return shellutil.Run("umount {0}".format(mountPoint), chk_err=chk_err)
|
||||
def mount(self, dvd, mount_point, option="", chk_err=True):
|
||||
cmd = "mount {0} {1} {2}".format(dvd, option, mount_point)
|
||||
return shellutil.run_get_output(cmd, chk_err)[0]
|
||||
|
||||
def OpenPortForDhcp(self):
|
||||
def umount(self, mount_point, chk_err=True):
|
||||
return shellutil.run("umount {0}".format(mount_point), chk_err=chk_err)
|
||||
|
||||
def allow_dhcp_broadcast(self):
|
||||
#Open DHCP port if iptables is enabled.
|
||||
# We supress error logging on error.
|
||||
shellutil.Run("iptables -D INPUT -p udp --dport 68 -j ACCEPT",
|
||||
chk_err=False)
|
||||
shellutil.Run("iptables -I INPUT -p udp --dport 68 -j ACCEPT",
|
||||
shellutil.run("iptables -D INPUT -p udp --dport 68 -j ACCEPT",
|
||||
chk_err=False)
|
||||
shellutil.run("iptables -I INPUT -p udp --dport 68 -j ACCEPT",
|
||||
chk_err=False)
|
||||
|
||||
def GenerateTransportCert(self):
|
||||
def gen_transport_cert(self):
|
||||
"""
|
||||
Create ssl certificate for https communication with endpoint server.
|
||||
"""
|
||||
cmd = ("{0} req -x509 -nodes -subj /CN=LinuxTransport -days 32768 "
|
||||
"-newkey rsa:2048 -keyout TransportPrivate.pem "
|
||||
"-out TransportCert.pem").format(self.opensslCmd)
|
||||
shellutil.Run(cmd)
|
||||
"-out TransportCert.pem").format(self.openssl_cmd)
|
||||
shellutil.run(cmd)
|
||||
|
||||
def RemoveRulesFiles(self, rulesFiles=RulesFiles):
|
||||
libDir = self.GetLibDir()
|
||||
for src in rulesFiles:
|
||||
fileName = fileutil.GetLastPathElement(src)
|
||||
dest = os.path.join(libDir, fileName)
|
||||
def remove_rules_files(self, rules_files=__RULES_FILES__):
|
||||
lib_dir = self.get_lib_dir()
|
||||
for src in rules_files:
|
||||
file_name = fileutil.base_name(src)
|
||||
dest = os.path.join(lib_dir, file_name)
|
||||
if os.path.isfile(dest):
|
||||
os.remove(dest)
|
||||
if os.path.isfile(src):
|
||||
logger.Warn("Move rules file {0} to {1}", fileName, dest)
|
||||
logger.warn("Move rules file {0} to {1}", file_name, dest)
|
||||
shutil.move(src, dest)
|
||||
|
||||
def RestoreRulesFiles(self, rulesFiles=RulesFiles):
|
||||
libDir = self.GetLibDir()
|
||||
for dest in rulesFiles:
|
||||
fileName = fileutil.GetLastPathElement(dest)
|
||||
src = os.path.join(libDir, fileName)
|
||||
def restore_rules_files(self, rules_files=__RULES_FILES__):
|
||||
lib_dir = self.get_lib_dir()
|
||||
for dest in rules_files:
|
||||
filename = fileutil.base_name(dest)
|
||||
src = os.path.join(lib_dir, filename)
|
||||
if os.path.isfile(dest):
|
||||
continue
|
||||
if os.path.isfile(src):
|
||||
logger.Warn("Move rules file {0} to {1}", fileName, dest)
|
||||
logger.warn("Move rules file {0} to {1}", filename, dest)
|
||||
shutil.move(src, dest)
|
||||
|
||||
def GetMacAddress(self):
|
||||
def get_mac_addr(self):
|
||||
"""
|
||||
Convienience function, returns mac addr bound to
|
||||
first non-loobback interface.
|
||||
"""
|
||||
ifname=''
|
||||
while len(ifname) < 2 :
|
||||
ifname=self.GetFirstActiveNetworkInterfaceNonLoopback()[0]
|
||||
addr = self.GetInterfaceMac(ifname)
|
||||
return textutil.HexStringToByteArray(addr)
|
||||
ifname=self.get_first_if()[0]
|
||||
addr = self.get_if_mac(ifname)
|
||||
return textutil.hexstr_to_bytearray(addr)
|
||||
|
||||
def GetInterfaceMac(self, ifname):
|
||||
def get_if_mac(self, ifname):
|
||||
"""
|
||||
Return the mac-address bound to the socket.
|
||||
"""
|
||||
sock = socket.socket(socket.AF_INET,
|
||||
socket.SOCK_DGRAM,
|
||||
sock = socket.socket(socket.AF_INET,
|
||||
socket.SOCK_DGRAM,
|
||||
socket.IPPROTO_UDP)
|
||||
param = struct.pack('256s', (ifname[:15]+('\0'*241)).encode('latin-1'))
|
||||
info = fcntl.ioctl(sock.fileno(), 0x8927, param)
|
||||
return ''.join(['%02X' % textutil.Ord(char) for char in info[18:24]])
|
||||
return ''.join(['%02X' % textutil.str_to_ord(char) for char in info[18:24]])
|
||||
|
||||
def GetFirstActiveNetworkInterfaceNonLoopback(self):
|
||||
def get_first_if(self):
|
||||
"""
|
||||
Return the interface name, and ip addr of the
|
||||
first active non-loopback interface.
|
||||
@@ -446,17 +446,17 @@ class DefaultOSUtil(object):
|
||||
iface=''
|
||||
expected=16 # how many devices should I expect...
|
||||
struct_size=40 # for 64bit the size is 40 bytes
|
||||
sock = socket.socket(socket.AF_INET,
|
||||
socket.SOCK_DGRAM,
|
||||
sock = socket.socket(socket.AF_INET,
|
||||
socket.SOCK_DGRAM,
|
||||
socket.IPPROTO_UDP)
|
||||
buff=array.array('B', b'\0' * (expected * struct_size))
|
||||
param = struct.pack('iL',
|
||||
expected*struct_size,
|
||||
param = struct.pack('iL',
|
||||
expected*struct_size,
|
||||
buff.buffer_info()[0])
|
||||
ret = fcntl.ioctl(sock.fileno(), 0x8912, param)
|
||||
retsize=(struct.unpack('iL', ret)[0])
|
||||
if retsize == (expected * struct_size):
|
||||
logger.Warn(('SIOCGIFCONF returned more than {0} up '
|
||||
logger.warn(('SIOCGIFCONF returned more than {0} up '
|
||||
'network interfaces.'), expected)
|
||||
sock = buff.tostring()
|
||||
for i in range(0, struct_size * expected, struct_size):
|
||||
@@ -467,114 +467,114 @@ class DefaultOSUtil(object):
|
||||
break
|
||||
return iface.decode('latin-1'), socket.inet_ntoa(sock[i+20:i+24])
|
||||
|
||||
def IsMissingDefaultRoute(self):
|
||||
routes = shellutil.RunGetOutput("route -n")[1]
|
||||
def is_missing_default_route(self):
|
||||
routes = shellutil.run_get_output("route -n")[1]
|
||||
for route in routes:
|
||||
if route.startswith("0.0.0.0 ") or route.startswith("default "):
|
||||
return False
|
||||
return False
|
||||
return True
|
||||
|
||||
def GetInterfaceName(self):
|
||||
return self.GetFirstActiveNetworkInterfaceNonLoopback()[0]
|
||||
|
||||
def GetIpv4Address(self):
|
||||
return self.GetFirstActiveNetworkInterfaceNonLoopback()[1]
|
||||
|
||||
def SetBroadcastRouteForDhcp(self, ifname):
|
||||
return shellutil.Run("route add 255.255.255.255 dev {0}".format(ifname),
|
||||
def get_if_name(self):
|
||||
return self.get_first_if()[0]
|
||||
|
||||
def get_ip4_addr(self):
|
||||
return self.get_first_if()[1]
|
||||
|
||||
def set_route_for_dhcp_broadcast(self, ifname):
|
||||
return shellutil.run("route add 255.255.255.255 dev {0}".format(ifname),
|
||||
chk_err=False)
|
||||
|
||||
def RemoveBroadcastRouteForDhcp(self, ifname):
|
||||
shellutil.Run("route del 255.255.255.255 dev {0}".format(ifname),
|
||||
def remove_route_for_dhcp_broadcast(self, ifname):
|
||||
shellutil.run("route del 255.255.255.255 dev {0}".format(ifname),
|
||||
chk_err=False)
|
||||
|
||||
def IsDhcpEnabled(self):
|
||||
def is_dhcp_enabled(self):
|
||||
return False
|
||||
|
||||
def StopDhcpService(self):
|
||||
def stop_dhcp_service(self):
|
||||
pass
|
||||
|
||||
def StartDhcpService(self):
|
||||
def start_dhcp_service(self):
|
||||
pass
|
||||
|
||||
def StartNetwork(self):
|
||||
def start_network(self):
|
||||
pass
|
||||
|
||||
def StartAgentService(self):
|
||||
def start_agent_service(self):
|
||||
pass
|
||||
|
||||
def StopAgentService(self):
|
||||
def stop_agent_service(self):
|
||||
pass
|
||||
|
||||
def RegisterAgentService(self):
|
||||
def register_agent_service(self):
|
||||
pass
|
||||
|
||||
def UnregisterAgentService(self):
|
||||
def unregister_agent_service(self):
|
||||
pass
|
||||
|
||||
def RestartSshService(self):
|
||||
def restart_ssh_service(self):
|
||||
pass
|
||||
|
||||
def RouteAdd(self, net, mask, gateway):
|
||||
|
||||
def route_add(self, net, mask, gateway):
|
||||
"""
|
||||
Add specified route using /sbin/route add -net.
|
||||
"""
|
||||
cmd = ("/sbin/route add -net "
|
||||
"{0} netmask {1} gw {2}").format(net, mask, gateway)
|
||||
return shellutil.Run(cmd, chk_err=False)
|
||||
return shellutil.run(cmd, chk_err=False)
|
||||
|
||||
def GetDhcpProcessId(self):
|
||||
ret= shellutil.RunGetOutput("pidof dhclient")
|
||||
def get_dhcp_pid(self):
|
||||
ret= shellutil.run_get_output("pidof dhclient")
|
||||
return ret[1] if ret[0] == 0 else None
|
||||
|
||||
def SetHostname(self, hostname):
|
||||
fileutil.SetFileContents('/etc/hostname', hostname)
|
||||
shellutil.Run("hostname {0}".format(hostname), chk_err=False)
|
||||
def set_hostname(self, hostname):
|
||||
fileutil.write_file('/etc/hostname', hostname)
|
||||
shellutil.run("hostname {0}".format(hostname), chk_err=False)
|
||||
|
||||
def SetDhcpHostname(self, hostname):
|
||||
autoSend = r'^[^#]*?send\s*host-name.*?(<hostname>|gethostname[(,)])'
|
||||
dhclientFiles = ['/etc/dhcp/dhclient.conf', '/etc/dhcp3/dhclient.conf']
|
||||
for confFile in dhclientFiles:
|
||||
if not os.path.isfile(confFile):
|
||||
def set_dhcp_hostname(self, hostname):
|
||||
autosend = r'^[^#]*?send\s*host-name.*?(<hostname>|gethostname[(,)])'
|
||||
dhclient_files = ['/etc/dhcp/dhclient.conf', '/etc/dhcp3/dhclient.conf']
|
||||
for conf_file in dhclient_files:
|
||||
if not os.path.isfile(conf_file):
|
||||
continue
|
||||
if fileutil.FindStringInFile(confFile, autoSend):
|
||||
if fileutil.findstr_in_file(conf_file, autosend):
|
||||
#Return if auto send host-name is configured
|
||||
return
|
||||
fileutil.UpdateConfigFile(confFile,
|
||||
fileutil.update_conf_file(conf_file,
|
||||
'send host-name',
|
||||
'send host-name {0}'.format(hostname))
|
||||
|
||||
def RestartInterface(self, ifname):
|
||||
shellutil.Run("ifdown {0} && ifup {1}".format(ifname, ifname))
|
||||
def restart_if(self, ifname):
|
||||
shellutil.run("ifdown {0} && ifup {1}".format(ifname, ifname))
|
||||
|
||||
def PublishHostname(self, hostname):
|
||||
self.SetDhcpHostname(hostname)
|
||||
ifname = self.GetInterfaceName()
|
||||
self.RestartInterface(ifname)
|
||||
def publish_hostname(self, hostname):
|
||||
self.set_dhcp_hostname(hostname)
|
||||
ifname = self.get_if_name()
|
||||
self.restart_if(ifname)
|
||||
|
||||
def SetScsiDiskTimeout(self, timeout):
|
||||
def set_scsi_disks_timeout(self, timeout):
|
||||
for dev in os.listdir("/sys/block"):
|
||||
if dev.startswith('sd'):
|
||||
self.SetBlockDeviceTimeout(dev, timeout)
|
||||
self.set_block_device_timeout(dev, timeout)
|
||||
|
||||
def SetBlockDeviceTimeout(self, dev, timeout):
|
||||
def set_block_device_timeout(self, dev, timeout):
|
||||
if dev is not None and timeout is not None:
|
||||
filePath = "/sys/block/{0}/device/timeout".format(dev)
|
||||
content = fileutil.GetFileContents(filePath)
|
||||
file_path = "/sys/block/{0}/device/timeout".format(dev)
|
||||
content = fileutil.read_file(file_path)
|
||||
original = content.splitlines()[0].rstrip()
|
||||
if original != timeout:
|
||||
fileutil.SetFileContents(filePath, timeout)
|
||||
logger.Info("Set block dev timeout: {0} with timeout: {1}",
|
||||
fileutil.write_file(file_path, timeout)
|
||||
logger.info("Set block dev timeout: {0} with timeout: {1}",
|
||||
dev, timeout)
|
||||
|
||||
def GetMountPoint(self, mountlist, device):
|
||||
def get_mount_point(self, mountlist, device):
|
||||
"""
|
||||
Example of mountlist:
|
||||
/dev/sda1 on / type ext4 (rw)
|
||||
proc on /proc type proc (rw)
|
||||
sysfs on /sys type sysfs (rw)
|
||||
devpts on /dev/pts type devpts (rw,gid=5,mode=620)
|
||||
tmpfs on /dev/shm type tmpfs
|
||||
tmpfs on /dev/shm type tmpfs
|
||||
(rw,rootcontext="system_u:object_r:tmpfs_t:s0")
|
||||
none on /proc/sys/fs/binfmt_misc type binfmt_misc (rw)
|
||||
/dev/sdb1 on /mnt/resource type ext4 (rw)
|
||||
@@ -586,24 +586,23 @@ class DefaultOSUtil(object):
|
||||
#Return the 3rd column of this line
|
||||
return tokens[2] if len(tokens) > 2 else None
|
||||
return None
|
||||
|
||||
def DeviceForIdePort(self, n):
|
||||
|
||||
def device_for_ide_port(self, port_id):
|
||||
"""
|
||||
Return device name attached to ide port 'n'.
|
||||
"""
|
||||
if n > 3:
|
||||
if port_id > 3:
|
||||
return None
|
||||
g0 = "00000000"
|
||||
if n > 1:
|
||||
if port_id > 1:
|
||||
g0 = "00000001"
|
||||
n = n - 2
|
||||
port_id = port_id - 2
|
||||
device = None
|
||||
path = "/sys/bus/vmbus/devices/"
|
||||
for vmbus in os.listdir(path):
|
||||
deviceid = fileutil.GetFileContents(os.path.join(path, vmbus,
|
||||
"device_id"))
|
||||
deviceid = fileutil.read_file(os.path.join(path, vmbus, "device_id"))
|
||||
guid = deviceid.lstrip('{').split('-')
|
||||
if guid[0] == g0 and guid[1] == "000" + str(n):
|
||||
if guid[0] == g0 and guid[1] == "000" + str(port_id):
|
||||
for root, dirs, files in os.walk(path + vmbus):
|
||||
if root.endswith("/block"):
|
||||
device = dirs[0]
|
||||
@@ -616,38 +615,37 @@ class DefaultOSUtil(object):
|
||||
break
|
||||
return device
|
||||
|
||||
def DeleteAccount(self, userName):
|
||||
if self.IsSysUser(userName):
|
||||
logger.Error("{0} is a system user. Will not delete it.", userName)
|
||||
shellutil.Run("> /var/run/utmp")
|
||||
shellutil.Run("userdel -f -r " + userName)
|
||||
def del_account(self, username):
|
||||
if self.is_sys_user(username):
|
||||
logger.error("{0} is a system user. Will not delete it.", username)
|
||||
shellutil.run("> /var/run/utmp")
|
||||
shellutil.run("userdel -f -r " + username)
|
||||
#Remove user from suders
|
||||
if os.path.isfile("/etc/suders.d/waagent"):
|
||||
try:
|
||||
content = fileutil.GetFileContents("/etc/sudoers.d/waagent")
|
||||
content = fileutil.read_file("/etc/sudoers.d/waagent")
|
||||
sudoers = content.split("\n")
|
||||
sudoers = filter(lambda x : userName not in x, sudoers)
|
||||
fileutil.SetFileContents("/etc/sudoers.d/waagent",
|
||||
sudoers = filter(lambda x : username not in x, sudoers)
|
||||
fileutil.write_file("/etc/sudoers.d/waagent",
|
||||
"\n".join(sudoers))
|
||||
except IOError as e:
|
||||
raise OSUtilError("Failed to remove sudoer: {0}".format(e))
|
||||
|
||||
def TranslateCustomData(self, data):
|
||||
|
||||
def decode_customdata(self, data):
|
||||
return data
|
||||
|
||||
def GetTotalMemory(self):
|
||||
def get_total_mem(self):
|
||||
cmd = "grep MemTotal /proc/meminfo |awk '{print $2}'"
|
||||
ret = shellutil.RunGetOutput(cmd)
|
||||
ret = shellutil.run_get_output(cmd)
|
||||
if ret[0] == 0:
|
||||
return int(ret[1])/1024
|
||||
else:
|
||||
raise OSUtilError("Failed to get total memory: {0}".format(ret[1]))
|
||||
|
||||
def GetProcessorCores(self):
|
||||
ret = shellutil.RunGetOutput("grep 'processor.*:' /proc/cpuinfo |wc -l")
|
||||
def get_processor_cores(self):
|
||||
ret = shellutil.run_get_output("grep 'processor.*:' /proc/cpuinfo |wc -l")
|
||||
if ret[0] == 0:
|
||||
return int(ret[1])
|
||||
else:
|
||||
raise OSUtilError("Failed to get procerssor cores")
|
||||
|
||||
OSUtil = DefaultOSUtil
|
||||
|
||||
@@ -18,133 +18,135 @@
|
||||
import os
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.conf as conf
|
||||
from azurelinuxagent.event import AddExtensionEvent, WALAEventOperation
|
||||
from azurelinuxagent.event import add_event, WALAEventOperation
|
||||
from azurelinuxagent.exception import *
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
import azurelinuxagent.protocol as prot
|
||||
import azurelinuxagent.protocol.ovfenv as ovf
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
|
||||
CustomDataFile="CustomData"
|
||||
CUSTOM_DATA_FILE="CustomData"
|
||||
|
||||
class ProvisionHandler(object):
|
||||
|
||||
def process(self):
|
||||
#If provision is not enabled, return
|
||||
if not conf.GetSwitch("Provisioning.Enabled", True):
|
||||
logger.Info("Provisioning is disabled. Skip.")
|
||||
if not conf.get_switch("Provisioning.Enabled", True):
|
||||
logger.info("Provisioning is disabled. Skip.")
|
||||
return
|
||||
|
||||
provisioned = os.path.join(OSUtil.GetLibDir(), "provisioned")
|
||||
provisioned = os.path.join(OSUTIL.get_lib_dir(), "provisioned")
|
||||
if os.path.isfile(provisioned):
|
||||
return
|
||||
|
||||
logger.Info("Run provision handler.")
|
||||
protocol = prot.Factory.getDefaultProtocol()
|
||||
|
||||
logger.info("run provision handler.")
|
||||
protocol = prot.Factory.get_default_protocol()
|
||||
try:
|
||||
status = prot.ProvisionStatus(status="NotReady",
|
||||
subStatus="ProvisionStatus")
|
||||
protocol.reportProvisionStatus(status)
|
||||
status = prot.ProvisionStatus(status="NotReady",
|
||||
subStatus="Provision started")
|
||||
protocol.report_provision_status(status)
|
||||
|
||||
self.provision()
|
||||
fileutil.SetFileContents(provisioned, "")
|
||||
thumbprint = self.regenerateSshHostKey()
|
||||
|
||||
logger.Info("Finished provisioning")
|
||||
fileutil.write_file(provisioned, "")
|
||||
thumbprint = self.reg_ssh_host_key()
|
||||
|
||||
logger.info("Finished provisioning")
|
||||
status = prot.ProvisionStatus(status="Ready")
|
||||
status.properties.certificateThumbprint = thumbprint
|
||||
protocol.reportProvisionStatus(status)
|
||||
protocol.report_provision_status(status)
|
||||
|
||||
AddExtensionEvent(name="WALA", isSuccess=True, message="",
|
||||
add_event(name="WALA", is_success=True, message="",
|
||||
op=WALAEventOperation.Provision)
|
||||
except ProvisionError as e:
|
||||
logger.Error("Provision failed: {0}", e)
|
||||
protocol.reportProvisionStatus(status="NotReady", subStatus=str(e))
|
||||
AddExtensionEvent(name="WALA", isSuccess=False, message=str(e),
|
||||
logger.error("Provision failed: {0}", e)
|
||||
status = prot.ProvisionStatus(status="NotReady",
|
||||
subStatus= str(e))
|
||||
protocol.report_provision_status(status)
|
||||
add_event(name="WALA", is_success=False, message=str(e),
|
||||
op=WALAEventOperation.Provision)
|
||||
|
||||
|
||||
def regenerateSshHostKey(self):
|
||||
keyPairType = conf.Get("Provisioning.SshHostKeyPairType", "rsa")
|
||||
if conf.GetSwitch("Provisioning.RegenerateSshHostKeyPair"):
|
||||
shellutil.Run("rm -f /etc/ssh/ssh_host_*key*")
|
||||
shellutil.Run(("ssh-keygen -N '' -t {0} -f /etc/ssh/ssh_host_{1}_key"
|
||||
"").format(keyPairType, keyPairType))
|
||||
thumbprint = self.getSshHostKeyThumbprint(keyPairType)
|
||||
|
||||
def reg_ssh_host_key(self):
|
||||
keypair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa")
|
||||
if conf.get_switch("Provisioning.RegenerateSshHostKeyPair"):
|
||||
shellutil.run("rm -f /etc/ssh/ssh_host_*key*")
|
||||
shellutil.run(("ssh-keygen -N '' -t {0} -f /etc/ssh/ssh_host_{1}_key"
|
||||
"").format(keypair_type, keypair_type))
|
||||
thumbprint = self.get_ssh_host_key_thumbprint(keypair_type)
|
||||
return thumbprint
|
||||
|
||||
def getSshHostKeyThumbprint(self, keyPairType):
|
||||
cmd = "ssh-keygen -lf /etc/ssh/ssh_host_{0}_key.pub".format(keyPairType)
|
||||
ret = shellutil.RunGetOutput(cmd)
|
||||
def get_ssh_host_key_thumbprint(self, keypair_type):
|
||||
cmd = "ssh-keygen -lf /etc/ssh/ssh_host_{0}_key.pub".format(keypair_type)
|
||||
ret = shellutil.run_get_output(cmd)
|
||||
if ret[0] == 0:
|
||||
return ret[1].rstrip().split()[1].replace(':', '')
|
||||
else:
|
||||
raise ProvisionError(("Failed to generate ssh host key: "
|
||||
"ret={0}, out= {1}").format(ret[0], ret[1]))
|
||||
|
||||
|
||||
|
||||
def provision(self):
|
||||
logger.Info("Copy ovf-env.xml.")
|
||||
logger.info("Copy ovf-env.xml.")
|
||||
try:
|
||||
ovfenv = ovf.CopyOvfEnv()
|
||||
ovfenv = ovf.copy_ovf_env()
|
||||
except prot.ProtocolError as e:
|
||||
raise ProvisionError("Failed to copy ovf-env.xml: {0}".format(e))
|
||||
|
||||
password = ovfenv.getUserPassword()
|
||||
ovfenv.clearUserPassword()
|
||||
|
||||
logger.Info("Set host name.")
|
||||
OSUtil.SetHostname(ovfenv.getComputerName())
|
||||
logger.Info("Publish host name.")
|
||||
OSUtil.PublishHostname(ovfenv.getComputerName())
|
||||
logger.Info("Create user account.")
|
||||
OSUtil.UpdateUserAccount(ovfenv.getUserName(), password)
|
||||
password = ovfenv.get_user_password()
|
||||
ovfenv.clear_user_password()
|
||||
|
||||
logger.info("Set host name.")
|
||||
OSUTIL.set_hostname(ovfenv.get_computer_name())
|
||||
logger.info("Publish host name.")
|
||||
OSUTIL.publish_hostname(ovfenv.get_computer_name())
|
||||
logger.info("Create user account.")
|
||||
OSUTIL.set_user_account(ovfenv.get_username(), password)
|
||||
|
||||
if password is not None:
|
||||
userSalt = conf.GetSwitch("Provision.UseSalt", True)
|
||||
saltType = conf.GetSwitch("Provision.SaltType", 6)
|
||||
logger.Info("Set user password.")
|
||||
OSUtil.ChangePassword(ovfenv.getUserName(), password, userSalt,
|
||||
saltType)
|
||||
use_salt = conf.get_switch("Provision.UseSalt", True)
|
||||
salt_type = conf.get_switch("Provision.SaltType", 6)
|
||||
logger.info("Set user password.")
|
||||
OSUTIL.chpasswd(ovfenv.get_username(), password, use_salt,
|
||||
salt_type)
|
||||
|
||||
logger.Info("Configure sshd.")
|
||||
OSUtil.ConfigSshd(ovfenv.getDisableSshPasswordAuthentication())
|
||||
logger.info("Configure sshd.")
|
||||
OSUTIL.conf_sshd(ovfenv.get_disable_ssh_password_auth())
|
||||
|
||||
#Disable selinux temporary
|
||||
sel = OSUtil.IsSelinuxRunning()
|
||||
sel = OSUTIL.is_selinux_enforcing()
|
||||
if sel:
|
||||
OSUtil.SetSelinuxEnforce(0)
|
||||
|
||||
self.deploySshPublicKeys(ovfenv)
|
||||
self.deploySshKeyPairs(ovfenv)
|
||||
self.saveCustomData(ovfenv)
|
||||
OSUTIL.set_selinux_enforce(0)
|
||||
|
||||
self.deploy_ssh_pubkeys(ovfenv)
|
||||
self.deploy_ssh_keypairs(ovfenv)
|
||||
self.save_customdata(ovfenv)
|
||||
|
||||
if sel:
|
||||
OSUtil.SetSelinuxEnforce(1)
|
||||
OSUTIL.set_selinux_enforce(1)
|
||||
|
||||
OSUtil.RestartSshService()
|
||||
OSUTIL.restart_ssh_service()
|
||||
|
||||
if conf.GetSwitch("Provisioning.DeleteRootPassword"):
|
||||
OSUtil.DeleteRootPassword()
|
||||
if conf.get_switch("Provisioning.DeleteRootPassword"):
|
||||
OSUTIL.del_root_password()
|
||||
|
||||
|
||||
def saveCustomData(self, ovfenv):
|
||||
logger.Info("Save custom data")
|
||||
customData = ovfenv.getCustomData()
|
||||
if customData is None:
|
||||
def save_customdata(self, ovfenv):
|
||||
logger.info("Save custom data")
|
||||
customdata = ovfenv.get_customdata()
|
||||
if customdata is None:
|
||||
return
|
||||
libDir = OSUtil.GetLibDir()
|
||||
fileutil.SetFileContents(os.path.join(libDir, CustomDataFile),
|
||||
OSUtil.TranslateCustomData(customData))
|
||||
lib_dir = OSUTIL.get_lib_dir()
|
||||
fileutil.write_file(os.path.join(lib_dir, CUSTOM_DATA_FILE),
|
||||
OSUTIL.decode_customdata(customdata))
|
||||
|
||||
def deploy_ssh_pubkeys(self, ovfenv):
|
||||
for thumbprint, path in ovfenv.get_ssh_pubkeys():
|
||||
logger.info("Deploy ssh public key.")
|
||||
OSUTIL.deploy_ssh_pubkey(ovfenv.get_username(), thumbprint, path)
|
||||
|
||||
def deploy_ssh_keypairs(self, ovfenv):
|
||||
for thumbprint, path in ovfenv.get_ssh_keypairs():
|
||||
logger.info("Deploy ssh key pairs.")
|
||||
OSUTIL.deploy_ssh_keypair(ovfenv.get_username(), thumbprint, path)
|
||||
|
||||
def deploySshPublicKeys(self, ovfenv):
|
||||
for thumbprint, path in ovfenv.getSshPublicKeys():
|
||||
logger.Info("Deploy ssh public key.")
|
||||
OSUtil.DeploySshPublicKey(ovfenv.getUserName(), thumbprint, path)
|
||||
|
||||
def deploySshKeyPairs(self, ovfenv):
|
||||
for thumbprint, path in ovfenv.getSshKeyPairs():
|
||||
logger.Info("Deploy ssh key pairs.")
|
||||
OSUtil.DeploySshKeyPair(ovfenv.getUserName(), thumbprint, path)
|
||||
|
||||
|
||||
@@ -22,15 +22,15 @@ import re
|
||||
import threading
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.conf as conf
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.event import AddExtensionEvent, WALAEventOperation
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
from azurelinuxagent.event import add_event, WALAEventOperation
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
from azurelinuxagent.exception import ResourceDiskError
|
||||
|
||||
DataLossWarningFile="DATALOSS_WARNING_README.txt"
|
||||
DataLossWarning="""\
|
||||
WARNING: THIS IS A TEMPORARY DISK.
|
||||
DATALOSS_WARNING_FILE_NAME="DATALOSS_WARNING_README.txt"
|
||||
DATA_LOSS_WARNING="""\
|
||||
WARNING: THIS IS A TEMPORARY DISK.
|
||||
|
||||
Any data stored on this drive is SUBJECT TO LOSS and THERE IS NO WAY TO RECOVER IT.
|
||||
|
||||
@@ -41,126 +41,126 @@ For additional details to please refer to the MSDN documentation at : http://msd
|
||||
|
||||
class ResourceDiskHandler(object):
|
||||
|
||||
def startActivateResourceDisk(self):
|
||||
diskThread = threading.Thread(target = self.run)
|
||||
diskThread.start()
|
||||
|
||||
def start_activate_resource_disk(self):
|
||||
disk_thread = threading.Thread(target = self.run)
|
||||
disk_thread.start()
|
||||
|
||||
def run(self):
|
||||
mountpoint = None
|
||||
if conf.GetSwitch("ResourceDisk.Format", False):
|
||||
mountpoint = self.activateResourceDisk()
|
||||
if mountpoint is not None and \
|
||||
conf.GetSwitch("ResourceDisk.EnableSwap", False):
|
||||
self.enableSwap(mountpoint)
|
||||
mount_point = None
|
||||
if conf.get_switch("ResourceDisk.Format", False):
|
||||
mount_point = self.activate_resource_disk()
|
||||
if mount_point is not None and \
|
||||
conf.get_switch("ResourceDisk.EnableSwap", False):
|
||||
self.enable_swap(mount_point)
|
||||
|
||||
def activateResourceDisk(self):
|
||||
logger.Info("Activate resource disk")
|
||||
def activate_resource_disk(self):
|
||||
logger.info("Activate resource disk")
|
||||
try:
|
||||
mountpoint = conf.Get("ResourceDisk.MountPoint", "/mnt/resource")
|
||||
fs = conf.Get("ResourceDisk.Filesystem", "ext3")
|
||||
mountpoint = self.mountResourceDisk(mountpoint, fs)
|
||||
warningFile = os.path.join(mountpoint, DataLossWarningFile)
|
||||
mount_point = conf.get("ResourceDisk.MountPoint", "/mnt/resource")
|
||||
fs = conf.get("ResourceDisk.Filesystem", "ext3")
|
||||
mount_point = self.mount_resource_disk(mount_point, fs)
|
||||
warning_file = os.path.join(mount_point, DATALOSS_WARNING_FILE_NAME)
|
||||
try:
|
||||
fileutil.SetFileContents(warningFile, DataLossWarning)
|
||||
fileutil.write_file(warning_file, DATA_LOSS_WARNING)
|
||||
except IOError as e:
|
||||
logger.Warn("Failed to write data loss warnning:{0}", e)
|
||||
return mountpoint
|
||||
logger.warn("Failed to write data loss warnning:{0}", e)
|
||||
return mount_point
|
||||
except ResourceDiskError as e:
|
||||
logger.Error("Failed to mount resource disk {0}", e)
|
||||
AddExtensionEvent(name="WALA", isSuccess=False, message=str(e),
|
||||
logger.error("Failed to mount resource disk {0}", e)
|
||||
add_event(name="WALA", is_success=False, message=str(e),
|
||||
op=WALAEventOperation.ActivateResourceDisk)
|
||||
|
||||
def enableSwap(self, mountpoint):
|
||||
logger.Info("Enable swap")
|
||||
try:
|
||||
sizeMB = conf.GetInt("ResourceDisk.SwapSizeMB", 0)
|
||||
self.createSwapSpace(mountpoint, sizeMB)
|
||||
except ResourceDiskError as e:
|
||||
logger.Error("Failed to enable swap {0}", e)
|
||||
|
||||
def mountResourceDisk(self, mountpoint, fs):
|
||||
device = OSUtil.DeviceForIdePort(1)
|
||||
def enable_swap(self, mount_point):
|
||||
logger.info("Enable swap")
|
||||
try:
|
||||
size_mb = conf.get_int("ResourceDisk.SwapSizeMB", 0)
|
||||
self.create_swap_space(mount_point, size_mb)
|
||||
except ResourceDiskError as e:
|
||||
logger.error("Failed to enable swap {0}", e)
|
||||
|
||||
def mount_resource_disk(self, mount_point, fs):
|
||||
device = OSUTIL.device_for_ide_port(1)
|
||||
if device is None:
|
||||
raise ResourceDiskError("unable to detect disk topology")
|
||||
|
||||
device = "/dev/" + device
|
||||
mountlist = shellutil.RunGetOutput("mount")[1]
|
||||
existing = OSUtil.GetMountPoint(mountlist, device)
|
||||
mountlist = shellutil.run_get_output("mount")[1]
|
||||
existing = OSUTIL.get_mount_point(mountlist, device)
|
||||
|
||||
if(existing):
|
||||
logger.Info("Resource disk {0}1 is already mounted", device)
|
||||
logger.info("Resource disk {0}1 is already mounted", device)
|
||||
return existing
|
||||
|
||||
fileutil.CreateDir(mountpoint, mode=0755)
|
||||
|
||||
logger.Info("Detect GPT...")
|
||||
fileutil.mkdir(mount_point, mode=0755)
|
||||
|
||||
logger.info("Detect GPT...")
|
||||
partition = device + "1"
|
||||
ret = shellutil.RunGetOutput("parted {0} print".format(device))
|
||||
ret = shellutil.run_get_output("parted {0} print".format(device))
|
||||
if ret[0]:
|
||||
raise ResourceDiskError("({0}) {1}".format(device, ret[1]))
|
||||
|
||||
|
||||
if "gpt" in ret[1]:
|
||||
logger.Info("GPT detected")
|
||||
logger.Info("Get GPT partitions")
|
||||
parts = filter(lambda x : re.match("^\s*[0-9]+", x),
|
||||
logger.info("GPT detected")
|
||||
logger.info("Get GPT partitions")
|
||||
parts = filter(lambda x : re.match("^\s*[0-9]+", x),
|
||||
ret[1].split("\n"))
|
||||
logger.Info("Found more than {0} GPT partitions.", len(parts))
|
||||
logger.info("Found more than {0} GPT partitions.", len(parts))
|
||||
if len(parts) > 1:
|
||||
logger.Info("Remove old GPT partitions")
|
||||
logger.info("Remove old GPT partitions")
|
||||
for i in range(1, len(parts) + 1):
|
||||
logger.Info("Remove partition: {0}", i)
|
||||
shellutil.Run("parted {0} rm {1}".format(device, i))
|
||||
logger.info("Remove partition: {0}", i)
|
||||
shellutil.run("parted {0} rm {1}".format(device, i))
|
||||
|
||||
logger.Info("Create a new GPT partition using entire disk space")
|
||||
shellutil.Run("parted {0} mkpart primary 0% 100%".format(device))
|
||||
|
||||
logger.Info("Format partition: {0} with fstype {1}",partition,fs)
|
||||
shellutil.Run("mkfs." + fs + " " + partition + " -F")
|
||||
logger.info("Create a new GPT partition using entire disk space")
|
||||
shellutil.run("parted {0} mkpart primary 0% 100%".format(device))
|
||||
|
||||
logger.info("Format partition: {0} with fstype {1}",partition,fs)
|
||||
shellutil.run("mkfs." + fs + " " + partition + " -F")
|
||||
else:
|
||||
logger.Info("GPT not detected")
|
||||
logger.Info("Check fstype")
|
||||
ret = shellutil.RunGetOutput("sfdisk -q -c {0} 1".format(device))
|
||||
logger.info("GPT not detected")
|
||||
logger.info("Check fstype")
|
||||
ret = shellutil.run_get_output("sfdisk -q -c {0} 1".format(device))
|
||||
if ret[1].rstrip() == "7" and fs != "ntfs":
|
||||
logger.Info("The partition is formatted with ntfs")
|
||||
logger.Info("Format partition: {0} with fstype {1}",partition,fs)
|
||||
shellutil.Run("sfdisk -c {0} 1 83".format(device))
|
||||
shellutil.Run("mkfs." + fs + " " + partition + " -F")
|
||||
logger.info("The partition is formatted with ntfs")
|
||||
logger.info("Format partition: {0} with fstype {1}",partition,fs)
|
||||
shellutil.run("sfdisk -c {0} 1 83".format(device))
|
||||
shellutil.run("mkfs." + fs + " " + partition + " -F")
|
||||
|
||||
logger.Info("Mount resource disk")
|
||||
retCode = shellutil.Run("mount {0} {1}".format(partition, mountpoint),
|
||||
logger.info("Mount resource disk")
|
||||
ret = shellutil.run("mount {0} {1}".format(partition, mount_point),
|
||||
chk_err=False)
|
||||
if retCode:
|
||||
logger.Warn("Failed to mount resource disk. Retry mounting")
|
||||
shellutil.Run("mkfs." + fs + " " + partition + " -F")
|
||||
retCode = shellutil.Run("mount {0} {1}".format(partition, mountpoint))
|
||||
if retCode:
|
||||
raise ResourceDiskError("({0}) {1}".format(partition, retCode))
|
||||
if ret:
|
||||
logger.warn("Failed to mount resource disk. Retry mounting")
|
||||
shellutil.run("mkfs." + fs + " " + partition + " -F")
|
||||
ret = shellutil.run("mount {0} {1}".format(partition, mount_point))
|
||||
if ret:
|
||||
raise ResourceDiskError("({0}) {1}".format(partition, ret))
|
||||
|
||||
logger.Info("Resource disk ({0}) is mounted at {1} with fstype {2}",
|
||||
device, mountpoint, fs)
|
||||
return mountpoint
|
||||
logger.info("Resource disk ({0}) is mounted at {1} with fstype {2}",
|
||||
device, mount_point, fs)
|
||||
return mount_point
|
||||
|
||||
def createSwapSpace(self, mountpoint, sizeMB):
|
||||
sizeKB = sizeMB * 1024
|
||||
size = sizeKB * 1024
|
||||
swapfile = os.path.join(mountpoint, 'swapfile')
|
||||
swapList = shellutil.RunGetOutput("swapon -s")[1]
|
||||
def create_swap_space(self, mount_point, size_mb):
|
||||
size_kb = size_mb * 1024
|
||||
size = size_kb * 1024
|
||||
swapfile = os.path.join(mount_point, 'swapfile')
|
||||
swaplist = shellutil.run_get_output("swapon -s")[1]
|
||||
|
||||
if swapfile in swapList and os.path.getsize(swapfile) == size:
|
||||
logger.Info("Swap already enabled")
|
||||
return
|
||||
if swapfile in swaplist and os.path.getsize(swapfile) == size:
|
||||
logger.info("Swap already enabled")
|
||||
return
|
||||
|
||||
if os.path.isfile(swapfile) and os.path.getsize(swapfile) != size:
|
||||
logger.Info("Remove old swap file")
|
||||
shellutil.Run("swapoff -a", chk_err=False)
|
||||
logger.info("Remove old swap file")
|
||||
shellutil.run("swapoff -a", chk_err=False)
|
||||
os.remove(swapfile)
|
||||
|
||||
if not os.path.isfile(swapfile):
|
||||
logger.Info("Create swap file")
|
||||
shellutil.Run(("dd if=/dev/zero of={0} bs=1024 "
|
||||
"count={1}").format(swapfile, sizeKB))
|
||||
shellutil.Run("mkswap {0}".format(swapfile))
|
||||
if shellutil.Run("swapon {0}".format(swapfile)):
|
||||
logger.info("Create swap file")
|
||||
shellutil.run(("dd if=/dev/zero of={0} bs=1024 "
|
||||
"count={1}").format(swapfile, size_kb))
|
||||
shellutil.run("mkswap {0}".format(swapfile))
|
||||
if shellutil.run("swapon {0}".format(swapfile)):
|
||||
raise ResourceDiskError("{0}".format(swapfile))
|
||||
logger.Info("Enabled {0}KB of swap at {1}".format(sizeKB, swapfile))
|
||||
logger.info("Enabled {0}KB of swap at {1}".format(size_kb, swapfile))
|
||||
|
||||
|
||||
@@ -21,61 +21,61 @@ import os
|
||||
import time
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.conf as conf
|
||||
from azurelinuxagent.metadata import GuestAgentLongName, GuestAgentVersion, \
|
||||
DistroName, DistroVersion, DistroFullName
|
||||
from azurelinuxagent.metadata import agent_long_name, AGENT_VERSION, \
|
||||
DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME
|
||||
import azurelinuxagent.protocol as prot
|
||||
import azurelinuxagent.event as event
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
|
||||
|
||||
class RunHandler(object):
|
||||
class MainHandler(object):
|
||||
def __init__(self, handlers):
|
||||
self.handlers = handlers
|
||||
|
||||
def run(self):
|
||||
logger.Info("{0} Version:{1}", GuestAgentLongName, GuestAgentVersion)
|
||||
logger.Info("OS: {0} {1}", DistroName, DistroVersion)
|
||||
logger.info("{0} Version:{1}", agent_long_name, AGENT_VERSION)
|
||||
logger.info("OS: {0} {1}", DISTRO_NAME, DISTRO_VERSION)
|
||||
|
||||
event.EnableUnhandledErrorDump("Azure Linux Agent")
|
||||
fileutil.SetFileContents(OSUtil.GetAgentPidPath(),
|
||||
event.enable_unhandled_err_dump("Azure Linux Agent")
|
||||
fileutil.write_file(OSUTIL.get_agent_pid_file_path(),
|
||||
str(os.getpid()))
|
||||
|
||||
if conf.GetSwitch("DetectScvmmEnv", False):
|
||||
if self.handlers.scvmmHandler.detectScvmmEnv():
|
||||
if conf.get_switch("DetectScvmmEnv", False):
|
||||
if self.handlers.scvmm_handler.detect_scvmm_env():
|
||||
return
|
||||
|
||||
self.handlers.dhcpHandler.probe()
|
||||
|
||||
prot.DetectDefaultProtocol()
|
||||
|
||||
event.EventMonitor().startEventsLoop()
|
||||
self.handlers.dhcp_handler.probe()
|
||||
|
||||
self.handlers.provisionHandler.process()
|
||||
prot.detect_default_protocol()
|
||||
|
||||
if conf.GetSwitch("ResourceDisk.Format", False):
|
||||
self.handlers.resourceDiskHandler.startActivateResourceDisk()
|
||||
|
||||
self.handlers.envHandler.startMonitor()
|
||||
event.EventMonitor().start()
|
||||
|
||||
protocol = prot.Factory.getDefaultProtocol()
|
||||
self.handlers.provision_handler.process()
|
||||
|
||||
if conf.get_switch("ResourceDisk.Format", False):
|
||||
self.handlers.resource_disk_handler.start_activate_resource_disk()
|
||||
|
||||
self.handlers.env_handler.start()
|
||||
|
||||
protocol = prot.Factory.get_default_protocol()
|
||||
while True:
|
||||
|
||||
#Handle extensions
|
||||
handlerStatusList = self.handlers.extensionHandler.process()
|
||||
h_status_list = self.handlers.extension_handler.process()
|
||||
|
||||
#Report status
|
||||
vmStatus = prot.VMStatus()
|
||||
vmStatus.vmAgent.agentVersion = GuestAgentLongName
|
||||
vmStatus.vmAgent.status = "Ready"
|
||||
vmStatus.vmAgent.message = "Guest Agent is running"
|
||||
for handlerStatus in handlerStatusList:
|
||||
vmStatus.extensionHandlers.append(handlerStatus)
|
||||
vm_status = prot.VMStatus()
|
||||
vm_status.vmAgent.agentVersion = agent_long_name
|
||||
vm_status.vmAgent.status = "Ready"
|
||||
vm_status.vmAgent.message = "Guest Agent is running"
|
||||
for h_status in h_status_list:
|
||||
vm_status.extensionHandlers.append(h_status)
|
||||
try:
|
||||
logger.Info("Report vm status")
|
||||
protocol.reportStatus(vmStatus)
|
||||
logger.info("Report vm status")
|
||||
protocol.report_status(vm_status)
|
||||
except prot.ProtocolError as e:
|
||||
logger.Error("Failed to report vm status: {0}", e)
|
||||
logger.error("Failed to report vm status: {0}", e)
|
||||
|
||||
time.sleep(25)
|
||||
|
||||
|
||||
@@ -20,28 +20,28 @@
|
||||
import os
|
||||
import subprocess
|
||||
import azurelinuxagent.logger as logger
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
|
||||
VmmConfigFileName = "linuxosconfiguration.xml"
|
||||
VmmStartupScriptName= "install"
|
||||
VMM_CONF_FILE_NAME = "linuxosconfiguration.xml"
|
||||
VMM_STARTUP_SCRIPT_NAME= "install"
|
||||
|
||||
class ScvmmHandler(object):
|
||||
|
||||
def detectScvmmEnv(self):
|
||||
logger.Info("Detecting Microsoft System Center VMM Environment")
|
||||
OSUtil.MountDvd(maxRetry=1, chk_err=False)
|
||||
mountPoint = OSUtil.GetDvdMountPoint()
|
||||
found = os.path.isfile(os.path.join(mountPoint, VmmConfigFileName))
|
||||
if found:
|
||||
self.startScvmmAgent()
|
||||
|
||||
def detect_scvmm_env(self):
|
||||
logger.info("Detecting Microsoft System Center VMM Environment")
|
||||
OSUTIL.mount_dvd(max_retry=1, chk_err=False)
|
||||
mount_point = OSUTIL.get_dvd_mount_point()
|
||||
found = os.path.isfile(os.path.join(mount_point, VMM_CONF_FILE_NAME))
|
||||
if found:
|
||||
self.start_scvmm_agent()
|
||||
else:
|
||||
OSUtil.UmountDvd(chk_err=False)
|
||||
OSUTIL.umount_dvd(chk_err=False)
|
||||
return found
|
||||
|
||||
def startScvmmAgent(self):
|
||||
logger.Info("Starting Microsoft System Center VMM Initialization "
|
||||
def start_scvmm_agent(self):
|
||||
logger.info("Starting Microsoft System Center VMM Initialization "
|
||||
"Process")
|
||||
mountPoint = OSUtil.GetDvdMountPoint()
|
||||
startupScript = os.path.join(mountPoint, VmmStartupScriptName)
|
||||
subprocess.Popen(["/bin/bash", startupScript, "-p " + mountPoint])
|
||||
mount_point = OSUTIL.get_dvd_mount_point()
|
||||
startup_script = os.path.join(mount_point, VMM_STARTUP_SCRIPT_NAME)
|
||||
subprocess.Popen(["/bin/bash", startup_script, "-p " + mount_point])
|
||||
|
||||
|
||||
@@ -16,31 +16,31 @@
|
||||
#
|
||||
|
||||
import azurelinuxagent.logger as logger
|
||||
from azurelinuxagent.metadata import DistroName
|
||||
import azurelinuxagent.distro.default.loader as defaultLoader
|
||||
from azurelinuxagent.metadata import DISTRO_NAME
|
||||
import azurelinuxagent.distro.default.loader as default_loader
|
||||
|
||||
|
||||
def GetDistroLoader():
|
||||
def get_distro_loader():
|
||||
try:
|
||||
logger.Verbose("Loading distro implemetation from: {0}", DistroName)
|
||||
pkgName = "azurelinuxagent.distro.{0}.loader".format(DistroName)
|
||||
return __import__(pkgName, fromlist="loader")
|
||||
logger.verb("Loading distro implemetation from: {0}", DISTRO_NAME)
|
||||
pkg_name = "azurelinuxagent.distro.{0}.loader".format(DISTRO_NAME)
|
||||
return __import__(pkg_name, fromlist="loader")
|
||||
except ImportError as e:
|
||||
logger.Warn("Unable to load distro implemetation for {0}.", DistroName)
|
||||
logger.Warn("Use default distro implemetation instead.")
|
||||
return defaultLoader
|
||||
logger.warn("Unable to load distro implemetation for {0}.", DISTRO_NAME)
|
||||
logger.warn("Use default distro implemetation instead.")
|
||||
return default_loader
|
||||
|
||||
distroLoader = GetDistroLoader()
|
||||
DISTRO_LOADER = get_distro_loader()
|
||||
|
||||
def GetOSUtil():
|
||||
def get_osutil():
|
||||
try:
|
||||
return distroLoader.GetOSUtil()
|
||||
return DISTRO_LOADER.get_osutil()
|
||||
except AttributeError:
|
||||
return defaultLoader.GetOSUtil()
|
||||
return default_loader.get_osutil()
|
||||
|
||||
def GetHandlers():
|
||||
def get_handlers():
|
||||
try:
|
||||
return distroLoader.GetHandlers()
|
||||
return DISTRO_LOADER.get_handlers()
|
||||
except AttributeError:
|
||||
return defaultLoader.GetHandlers()
|
||||
return default_loader.get_handlers()
|
||||
|
||||
|
||||
@@ -17,9 +17,9 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
from azurelinuxagent.metadata import DistroName, DistroVersion
|
||||
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION
|
||||
import azurelinuxagent.distro.redhat.loader as redhat
|
||||
|
||||
def GetOSUtil():
|
||||
return redhat.GetOSUtil()
|
||||
|
||||
def get_osutil():
|
||||
return redhat.get_osutil()
|
||||
|
||||
|
||||
@@ -17,11 +17,11 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
from azurelinuxagent.metadata import DistroName, DistroVersion
|
||||
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION
|
||||
|
||||
def GetOSUtil():
|
||||
def get_osutil():
|
||||
from azurelinuxagent.distro.redhat.osutil import Redhat6xOSUtil, RedhatOSUtil
|
||||
if DistroVersion < "7":
|
||||
if DISTRO_VERSION < "7":
|
||||
return Redhat6xOSUtil()
|
||||
else:
|
||||
return RedhatOSUtil()
|
||||
|
||||
@@ -30,54 +30,54 @@ import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import azurelinuxagent.utils.textutil as textutil
|
||||
from azurelinuxagent.distro.default.osutil import OSUtil, OSUtilError
|
||||
from azurelinuxagent.distro.default.osutil import DefaultOSUtil, OSUtilError
|
||||
|
||||
class Redhat6xOSUtil(OSUtil):
|
||||
class Redhat6xOSUtil(DefaultOSUtil):
|
||||
def __init__(self):
|
||||
super(Redhat6xOSUtil, self).__init__()
|
||||
self.sshdConfigPath = '/etc/ssh/sshd_config'
|
||||
self.opensslCmd = '/usr/bin/openssl'
|
||||
self.configPath = '/etc/waagent.conf'
|
||||
self.sshd_conf_file_path = '/etc/ssh/sshd_config'
|
||||
self.openssl_cmd = '/usr/bin/openssl'
|
||||
self.conf_file_path = '/etc/waagent.conf'
|
||||
self.selinux=None
|
||||
|
||||
def StartNetwork(self):
|
||||
return shellutil.Run("/sbin/service networking start", chk_err=False)
|
||||
def start_network(self):
|
||||
return shellutil.run("/sbin/service networking start", chk_err=False)
|
||||
|
||||
def RestartSshService(self):
|
||||
return shellutil.Run("/sbin/service sshd condrestart", chk_err=False)
|
||||
def restart_ssh_service(self):
|
||||
return shellutil.run("/sbin/service sshd condrestart", chk_err=False)
|
||||
|
||||
def StopAgentService(self):
|
||||
return shellutil.Run("/sbin/service waagent stop", chk_err=False)
|
||||
def stop_agent_service(self):
|
||||
return shellutil.run("/sbin/service waagent stop", chk_err=False)
|
||||
|
||||
def StartAgentService(self):
|
||||
return shellutil.Run("/sbin/service waagent start", chk_err=False)
|
||||
def start_agent_service(self):
|
||||
return shellutil.run("/sbin/service waagent start", chk_err=False)
|
||||
|
||||
def RegisterAgentService(self):
|
||||
return shellutil.Run("chkconfig --add waagent", chk_err=False)
|
||||
|
||||
def UnregisterAgentService(self):
|
||||
return shellutil.Run("chkconfig --del waagent", chk_err=False)
|
||||
def register_agent_service(self):
|
||||
return shellutil.run("chkconfig --add waagent", chk_err=False)
|
||||
|
||||
def RsaPublicKeyToSshRsa(self, publicKey):
|
||||
lines = publicKey.split("\n")
|
||||
def unregister_agent_service(self):
|
||||
return shellutil.run("chkconfig --del waagent", chk_err=False)
|
||||
|
||||
def asn1_to_ssh_rsa(self, pubkey):
|
||||
lines = pubkey.split("\n")
|
||||
lines = filter(lambda x : not x.startswith("----"), lines)
|
||||
base64Encoded = "".join(lines)
|
||||
base64_encoded = "".join(lines)
|
||||
try:
|
||||
#TODO remove pyasn1 dependency
|
||||
from pyasn1.codec.der import decoder as der_decoder
|
||||
derEncoded = base64.b64decode(base64Encoded)
|
||||
derEncoded = der_decoder.decode(derEncoded)[0][1]
|
||||
k = der_decoder.decode(textutil.BitsToString(derEncoded))[0]
|
||||
der_encoded = base64.b64decode(base64_encoded)
|
||||
der_encoded = der_decoder.decode(der_encoded)[0][1]
|
||||
k = der_decoder.decode(textutil.bits_to_str(der_encoded))[0]
|
||||
n=k[0]
|
||||
e=k[1]
|
||||
keydata=""
|
||||
keydata += struct.pack('>I',len("ssh-rsa"))
|
||||
keydata += "ssh-rsa"
|
||||
keydata += struct.pack('>I',len(textutil.NumberToBytes(e)))
|
||||
keydata += textutil.NumberToBytes(e)
|
||||
keydata += struct.pack('>I',len(textutil.NumberToBytes(n)) + 1)
|
||||
keydata += struct.pack('>I',len(textutil.num_to_bytes(e)))
|
||||
keydata += textutil.num_to_bytes(e)
|
||||
keydata += struct.pack('>I',len(textutil.num_to_bytes(n)) + 1)
|
||||
keydata += "\0"
|
||||
keydata += textutil.NumberToBytes(n)
|
||||
keydata += textutil.num_to_bytes(n)
|
||||
return "ssh-rsa " + base64.b64encode(keydata) + "\n"
|
||||
except ImportError as e:
|
||||
raise OSUtilError("Failed to load pyasn1.codec.der")
|
||||
@@ -85,37 +85,37 @@ class Redhat6xOSUtil(OSUtil):
|
||||
raise OSUtilError(("Failed to convert public key: {0} {1}"
|
||||
"").format(type(e).__name__, e))
|
||||
|
||||
def OpenSslToOpenSsh(self, inputFile, outputFile):
|
||||
publicKey = fileutil.GetFileContents(inputFile)
|
||||
sshRsaPublicKey = self.RsaPublicKeyToSshRsa(publicKey)
|
||||
fileutil.SetFileContents(outputFile, sshRsaPublicKey)
|
||||
def openssl_to_openssh(self, input_file, output_file):
|
||||
pubkey = fileutil.read_file(input_file)
|
||||
ssh_rsa_pubkey = self.asn1_to_ssh_rsa(pubkey)
|
||||
fileutil.write_file(output_file, ssh_rsa_pubkey)
|
||||
|
||||
#Override
|
||||
def GetDhcpProcessId(self):
|
||||
ret= shellutil.RunGetOutput("pidof dhclient")
|
||||
def get_dhcp_pid(self):
|
||||
ret= shellutil.run_get_output("pidof dhclient")
|
||||
return ret[1] if ret[0] == 0 else None
|
||||
|
||||
class RedhatOSUtil(Redhat6xOSUtil):
|
||||
def __init__(self):
|
||||
super(RedhatOSUtil, self).__init__()
|
||||
|
||||
def SetHostname(self, hostname):
|
||||
super(RedhatOSUtil, self).SetHostname(hostname)
|
||||
fileutil.UpdateConfigFile('/etc/sysconfig/network',
|
||||
def set_hostname(self, hostname):
|
||||
super(RedhatOSUtil, self).set_hostname(hostname)
|
||||
fileutil.update_conf_file('/etc/sysconfig/network',
|
||||
'HOSTNAME',
|
||||
'HOSTNAME={0}'.format(hostname))
|
||||
|
||||
def SetDhcpHostname(self, hostname):
|
||||
ifname = self.GetInterfaceName()
|
||||
|
||||
def set_dhcp_hostname(self, hostname):
|
||||
ifname = self.get_if_name()
|
||||
filepath = "/etc/sysconfig/network-scripts/ifcfg-{0}".format(ifname)
|
||||
fileutil.UpdateConfigFile(filepath,
|
||||
fileutil.update_conf_file(filepath,
|
||||
'DHCP_HOSTNAME',
|
||||
'DHCP_HOSTNAME={0}'.format(hostname))
|
||||
|
||||
def RegisterAgentService(self):
|
||||
return shellutil.Run("systemctl enable waagent", chk_err=False)
|
||||
|
||||
def UnregisterAgentService(self):
|
||||
return shellutil.Run("systemctl disable waagent", chk_err=False)
|
||||
def register_agent_service(self):
|
||||
return shellutil.run("systemctl enable waagent", chk_err=False)
|
||||
|
||||
def unregister_agent_service(self):
|
||||
return shellutil.run("systemctl disable waagent", chk_err=False)
|
||||
|
||||
|
||||
|
||||
@@ -17,12 +17,12 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
from azurelinuxagent.metadata import DistroName, DistroVersion, DistroFullName
|
||||
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME
|
||||
|
||||
def GetOSUtil():
|
||||
def get_osutil():
|
||||
from azurelinuxagent.distro.suse.osutil import SUSE11OSUtil, SUSEOSUtil
|
||||
if DistroFullName=='SUSE Linux Enterprise Server' and DistroVersion < '12' \
|
||||
or DistroFullName == 'openSUSE' and DistroVersion < '13.2':
|
||||
if DISTRO_FULL_NAME=='SUSE Linux Enterprise Server' and DISTRO_VERSION < '12' \
|
||||
or DISTRO_FULL_NAME == 'openSUSE' and DISTRO_VERSION < '13.2':
|
||||
return SUSE11OSUtil()
|
||||
else:
|
||||
return SUSEOSUtil()
|
||||
|
||||
@@ -29,60 +29,60 @@ import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import azurelinuxagent.utils.textutil as textutil
|
||||
from azurelinuxagent.metadata import DistroName, DistroVersion, DistroFullName
|
||||
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME
|
||||
from azurelinuxagent.distro.default.osutil import DefaultOSUtil
|
||||
|
||||
class SUSE11OSUtil(DefaultOSUtil):
|
||||
def __init__(self):
|
||||
super(SUSE11OSUtil, self).__init__()
|
||||
self.dhcpClientName='dhcpcd'
|
||||
self.dhclient_name='dhcpcd'
|
||||
|
||||
def SetHostname(self, hostname):
|
||||
fileutil.SetFileContents('/etc/HOSTNAME', hostname)
|
||||
shellutil.Run("hostname {0}".format(hostname), chk_err=False)
|
||||
def set_hostname(self, hostname):
|
||||
fileutil.write_file('/etc/HOSTNAME', hostname)
|
||||
shellutil.run("hostname {0}".format(hostname), chk_err=False)
|
||||
|
||||
def GetDhcpProcessId(self):
|
||||
ret= shellutil.RunGetOutput("pidof {0}".format(self.dhcpClientName))
|
||||
def get_dhcp_pid(self):
|
||||
ret= shellutil.run_get_output("pidof {0}".format(self.dhclient_name))
|
||||
return ret[1] if ret[0] == 0 else None
|
||||
|
||||
def IsDhcpEnabled(self):
|
||||
|
||||
def is_dhcp_enabled(self):
|
||||
return True
|
||||
|
||||
def StopDhcpService(self):
|
||||
cmd = "/sbin/service {0} stop".format(self.dhcpClientName)
|
||||
return shellutil.Run(cmd, chk_err=False)
|
||||
def stop_dhcp_service(self):
|
||||
cmd = "/sbin/service {0} stop".format(self.dhclient_name)
|
||||
return shellutil.run(cmd, chk_err=False)
|
||||
|
||||
def StartDhcpService(self):
|
||||
cmd = "/sbin/service {0} start".format(self.dhcpClientName)
|
||||
return shellutil.Run(cmd, chk_err=False)
|
||||
def start_dhcp_service(self):
|
||||
cmd = "/sbin/service {0} start".format(self.dhclient_name)
|
||||
return shellutil.run(cmd, chk_err=False)
|
||||
|
||||
def StartNetwork(self) :
|
||||
return shellutil.Run("/sbin/service start network", chk_err=False)
|
||||
def start_network(self) :
|
||||
return shellutil.run("/sbin/service start network", chk_err=False)
|
||||
|
||||
def RestartSshService(self):
|
||||
return shellutil.Run("/sbin/service sshd restart", chk_err=False)
|
||||
def restart_ssh_service(self):
|
||||
return shellutil.run("/sbin/service sshd restart", chk_err=False)
|
||||
|
||||
def StopAgentService(self):
|
||||
return shellutil.Run("/sbin/service waagent stop", chk_err=False)
|
||||
def stop_agent_service(self):
|
||||
return shellutil.run("/sbin/service waagent stop", chk_err=False)
|
||||
|
||||
def StartAgentService(self):
|
||||
return shellutil.Run("/sbin/service waagent start", chk_err=False)
|
||||
|
||||
def RegisterAgentService(self):
|
||||
return shellutil.Run("/sbin/insserv waagent", chk_err=False)
|
||||
|
||||
def UnregisterAgentService(self):
|
||||
return shellutil.Run("/sbin/insserv -r waagent", chk_err=False)
|
||||
def start_agent_service(self):
|
||||
return shellutil.run("/sbin/service waagent start", chk_err=False)
|
||||
|
||||
def register_agent_service(self):
|
||||
return shellutil.run("/sbin/insserv waagent", chk_err=False)
|
||||
|
||||
def unregister_agent_service(self):
|
||||
return shellutil.run("/sbin/insserv -r waagent", chk_err=False)
|
||||
|
||||
class SUSEOSUtil(SUSE11OSUtil):
|
||||
def __init__(self):
|
||||
super(SUSEOSUtil, self).__init__()
|
||||
self.dhcpClientName = 'wickedd-dhcp4'
|
||||
self.dhclient_name = 'wickedd-dhcp4'
|
||||
|
||||
def RegisterAgentService(self):
|
||||
return shellutil.Run("systemctl enable waagent", chk_err=False)
|
||||
|
||||
def UnregisterAgentService(self):
|
||||
return shellutil.Run("systemctl disable waagent", chk_err=False)
|
||||
def register_agent_service(self):
|
||||
return shellutil.run("systemctl enable waagent", chk_err=False)
|
||||
|
||||
def unregister_agent_service(self):
|
||||
return shellutil.run("systemctl disable waagent", chk_err=False)
|
||||
|
||||
|
||||
|
||||
@@ -22,22 +22,22 @@ import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
from azurelinuxagent.distro.default.deprovision import DeprovisionHandler, DeprovisionAction
|
||||
|
||||
def DeleteResolve():
|
||||
def del_resolv():
|
||||
if os.path.realpath('/etc/resolv.conf') != '/run/resolvconf/resolv.conf':
|
||||
logger.Info("resolvconf is not configured. Removing /etc/resolv.conf")
|
||||
fileutil.RemoveFiles('/etc/resolv.conf')
|
||||
logger.info("resolvconf is not configured. Removing /etc/resolv.conf")
|
||||
fileutil.rm_files('/etc/resolv.conf')
|
||||
else:
|
||||
logger.Info("resolvconf is enabled; leaving /etc/resolv.conf intact")
|
||||
fileutil.RemoveFiles('/etc/resolvconf/resolv.conf.d/tail',
|
||||
logger.info("resolvconf is enabled; leaving /etc/resolv.conf intact")
|
||||
fileutil.rm_files('/etc/resolvconf/resolv.conf.d/tail',
|
||||
'/etc/resolvconf/resolv.conf.d/originial')
|
||||
|
||||
|
||||
class UbuntuDeprovisionHandler(DeprovisionHandler):
|
||||
def setUp(self, deluser):
|
||||
warnings, actions = super(UbuntuDeprovisionHandler, self).setUp(deluser)
|
||||
def setup(self, deluser):
|
||||
warnings, actions = super(UbuntuDeprovisionHandler, self).setup(deluser)
|
||||
warnings.append("WARNING! Nameserver configuration in "
|
||||
"/etc/resolvconf/resolv.conf.d/{tail,originial} "
|
||||
"will be deleted.")
|
||||
actions.append(DeprovisionAction(DeleteResolve))
|
||||
actions.append(DeprovisionAction(del_resolv))
|
||||
return warnings, actions
|
||||
|
||||
|
||||
@@ -17,11 +17,13 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
from provision import UbuntuProvisionHandler
|
||||
from azurelinuxagent.distro.ubuntu.provision import UbuntuProvisionHandler
|
||||
from azurelinuxagent.distro.ubuntu.deprovision import UbuntuDeprovisionHandler
|
||||
from azurelinuxagent.distro.default.handlerFactory import DefaultHandlerFactory
|
||||
|
||||
class UbuntuHandlerFactory(DefaultHandlerFactory):
|
||||
def __init__(self):
|
||||
super(UbuntuHandlerFactory, self).__init__()
|
||||
self.provisionHandler = UbuntuProvisionHandler()
|
||||
self.provision_handler = UbuntuProvisionHandler()
|
||||
self.deprovision_handler = UbuntuDeprovisionHandler()
|
||||
|
||||
|
||||
@@ -17,20 +17,20 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
from azurelinuxagent.metadata import DistroName, DistroVersion
|
||||
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION
|
||||
|
||||
def GetOSUtil():
|
||||
def get_osutil():
|
||||
from azurelinuxagent.distro.ubuntu.osutil import Ubuntu1204OSUtil, \
|
||||
UbuntuOSUtil, \
|
||||
Ubuntu14xOSUtil
|
||||
if DistroVersion == "12.04":
|
||||
if DISTRO_VERSION == "12.04":
|
||||
return Ubuntu1204OSUtil()
|
||||
elif DistroVersion == "14.04" or DistroVersion == "14.10":
|
||||
elif DISTRO_VERSION == "14.04" or DISTRO_VERSION == "14.10":
|
||||
return Ubuntu14xOSUtil()
|
||||
else:
|
||||
return UbuntuOSUtil()
|
||||
|
||||
def GetHandlers():
|
||||
def get_handlers():
|
||||
from azurelinuxagent.distro.ubuntu.handlerFactory import UbuntuHandlerFactory
|
||||
return UbuntuHandlerFactory()
|
||||
|
||||
|
||||
|
||||
@@ -35,31 +35,31 @@ class Ubuntu14xOSUtil(DefaultOSUtil):
|
||||
def __init__(self):
|
||||
super(Ubuntu14xOSUtil, self).__init__()
|
||||
|
||||
def StartNetwork(self):
|
||||
return shellutil.Run("service networking start", chk_err=False)
|
||||
|
||||
def StopAgentService(self):
|
||||
return shellutil.Run("service walinuxagent stop", chk_err=False)
|
||||
def start_network(self):
|
||||
return shellutil.run("service networking start", chk_err=False)
|
||||
|
||||
def StartAgentService(self):
|
||||
return shellutil.Run("service walinuxagent start", chk_err=False)
|
||||
def stop_agent_service(self):
|
||||
return shellutil.run("service walinuxagent stop", chk_err=False)
|
||||
|
||||
def start_agent_service(self):
|
||||
return shellutil.run("service walinuxagent start", chk_err=False)
|
||||
|
||||
class Ubuntu1204OSUtil(Ubuntu14xOSUtil):
|
||||
def __init__(self):
|
||||
super(Ubuntu1204OSUtil, self).__init__()
|
||||
|
||||
#Override
|
||||
def GetDhcpProcessId(self):
|
||||
ret= shellutil.RunGetOutput("pidof dhclient3")
|
||||
def get_dhcp_pid(self):
|
||||
ret= shellutil.run_get_output("pidof dhclient3")
|
||||
return ret[1] if ret[0] == 0 else None
|
||||
|
||||
class UbuntuOSUtil(Ubuntu14xOSUtil):
|
||||
def __init__(self):
|
||||
super(UbuntuOSUtil, self).__init__()
|
||||
|
||||
def RegisterAgentService(self):
|
||||
return shellutil.Run("systemctl unmask walinuxagent", chk_err=False)
|
||||
|
||||
def UnregisterAgentService(self):
|
||||
return shellutil.Run("systemctl mask walinuxagent", chk_err=False)
|
||||
def register_agent_service(self):
|
||||
return shellutil.run("systemctl unmask walinuxagent", chk_err=False)
|
||||
|
||||
def unregister_agent_service(self):
|
||||
return shellutil.run("systemctl mask walinuxagent", chk_err=False)
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.conf as conf
|
||||
import azurelinuxagent.protocol as prot
|
||||
from azurelinuxagent.exception import *
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
from azurelinuxagent.distro.default.provision import ProvisionHandler
|
||||
@@ -34,38 +34,38 @@ On ubuntu image, provision could be disabled.
|
||||
class UbuntuProvisionHandler(ProvisionHandler):
|
||||
def process(self):
|
||||
#If provision is enabled, run default provision handler
|
||||
if conf.GetSwitch("Provisioning.Enabled", False):
|
||||
if conf.get_switch("Provisioning.Enabled", False):
|
||||
super(UbuntuProvisionHandler, self).process()
|
||||
return
|
||||
|
||||
logger.Info("Run Ubuntu provision handler")
|
||||
provisioned = os.path.join(OSUtil.GetLibDir(), "provisioned")
|
||||
logger.info("run Ubuntu provision handler")
|
||||
provisioned = os.path.join(OSUTIL.get_lib_dir(), "provisioned")
|
||||
if os.path.isfile(provisioned):
|
||||
return
|
||||
|
||||
logger.Info("Waiting cloud-init to finish provisioning.")
|
||||
protocol = prot.Factory.getDefaultProtocol()
|
||||
logger.info("Waiting cloud-init to finish provisioning.")
|
||||
protocol = prot.Factory.get_default_protocol()
|
||||
try:
|
||||
logger.Info("Wait for ssh host key to be generated.")
|
||||
thumbprint = self.waitForSshHostKey()
|
||||
fileutil.SetFileContents(provisioned, "")
|
||||
logger.info("Wait for ssh host key to be generated.")
|
||||
thumbprint = self.wait_for_ssh_host_key()
|
||||
fileutil.write_file(provisioned, "")
|
||||
|
||||
logger.Info("Finished provisioning")
|
||||
logger.info("Finished provisioning")
|
||||
status = prot.ProvisionStatus(status="Ready")
|
||||
status.properties.certificateThumbprint = thumbprint
|
||||
protocol.reportProvisionStatus(status)
|
||||
protocol.report_provision_status(status)
|
||||
|
||||
except ProvisionError as e:
|
||||
logger.Error("Provision failed: {0}", e)
|
||||
protocol.reportProvisionStatus(status="NotReady", subStatus=str(e))
|
||||
logger.error("Provision failed: {0}", e)
|
||||
protocol.report_provision_status(status="NotReady", subStatus=str(e))
|
||||
|
||||
def waitForSshHostKey(self, maxRetry=60):
|
||||
keyPairType = conf.Get("Provisioning.SshHostKeyPairType", "rsa")
|
||||
path = '/etc/ssh/ssh_host_{0}_key'.format(keyPairType)
|
||||
for retry in range(0, maxRetry):
|
||||
def wait_for_ssh_host_key(self, max_retry=60):
|
||||
kepair_type = conf.get("Provisioning.SshHostKeyPairType", "rsa")
|
||||
path = '/etc/ssh/ssh_host_{0}_key'.format(kepair_type)
|
||||
for retry in range(0, max_retry):
|
||||
if os.path.isfile(path):
|
||||
return self.getSshHostKeyThumbprint(keyPairType)
|
||||
if retry < maxRetry - 1:
|
||||
logger.Info("Wait for ssh host key be generated: {0}", path)
|
||||
return self.get_ssh_host_key_thumbprint(kepair_type)
|
||||
if retry < max_retry - 1:
|
||||
logger.info("Wait for ssh host key be generated: {0}", path)
|
||||
time.sleep(5)
|
||||
raise ProvisionError("Ssh hsot key is not generated.")
|
||||
|
||||
+89
-89
@@ -26,9 +26,9 @@ import threading
|
||||
import platform
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.protocol as prot
|
||||
from azurelinuxagent.metadata import DistroName, DistroVersion, DistroCodeName,\
|
||||
GuestAgentVersion
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.metadata import DISTRO_NAME, DISTRO_VERSION, DISTRO_CODE_NAME,\
|
||||
AGENT_VERSION
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
|
||||
class EventError(Exception):
|
||||
pass
|
||||
@@ -42,108 +42,108 @@ class WALAEventOperation:
|
||||
Enable = "Enable"
|
||||
Download = "Download"
|
||||
Upgrade = "Upgrade"
|
||||
Update = "Update"
|
||||
Update = "Update"
|
||||
ActivateResourceDisk="ActivateResourceDisk"
|
||||
UnhandledError="UnhandledError"
|
||||
|
||||
|
||||
class EventMonitor(object):
|
||||
def __init__(self):
|
||||
self.sysInfo = []
|
||||
self.eventDir = os.path.join(OSUtil.GetLibDir(), "events")
|
||||
self.initSystemInfo()
|
||||
self.sysinfo = []
|
||||
self.event_dir = os.path.join(OSUTIL.get_lib_dir(), "events")
|
||||
self.init_sysinfo()
|
||||
|
||||
def initSystemInfo(self):
|
||||
osversion = "{0}:{1}-{2}-{3}:{4}".format(platform.system(),
|
||||
DistroName,
|
||||
DistroVersion,
|
||||
DistroCodeName,
|
||||
def init_sysinfo(self):
|
||||
osversion = "{0}:{1}-{2}-{3}:{4}".format(platform.system(),
|
||||
DISTRO_NAME,
|
||||
DISTRO_VERSION,
|
||||
DISTRO_CODE_NAME,
|
||||
platform.release())
|
||||
self.sysInfo.append(prot.TelemetryEventParam("OSVersion", osversion))
|
||||
self.sysInfo.append(prot.TelemetryEventParam("GAVersion",
|
||||
GuestAgentVersion))
|
||||
self.sysInfo.append(prot.TelemetryEventParam("RAM",
|
||||
OSUtil.GetTotalMemory()))
|
||||
self.sysInfo.append(prot.TelemetryEventParam("Processors",
|
||||
OSUtil.GetProcessorCores()))
|
||||
protocol = prot.Factory.getDefaultProtocol()
|
||||
metadata = protocol.getInstanceMetadata()
|
||||
self.sysInfo.append(prot.TelemetryEventParam("TenantName",
|
||||
self.sysinfo.append(prot.TelemetryEventParam("OSVersion", osversion))
|
||||
self.sysinfo.append(prot.TelemetryEventParam("GAVersion",
|
||||
AGENT_VERSION))
|
||||
self.sysinfo.append(prot.TelemetryEventParam("RAM",
|
||||
OSUTIL.get_total_mem()))
|
||||
self.sysinfo.append(prot.TelemetryEventParam("Processors",
|
||||
OSUTIL.get_processor_cores()))
|
||||
protocol = prot.Factory.get_default_protocol()
|
||||
metadata = protocol.get_instance_metadata()
|
||||
self.sysinfo.append(prot.TelemetryEventParam("TenantName",
|
||||
metadata.deploymentName))
|
||||
self.sysInfo.append(prot.TelemetryEventParam("RoleName",
|
||||
self.sysinfo.append(prot.TelemetryEventParam("RoleName",
|
||||
metadata.roleName))
|
||||
self.sysInfo.append(prot.TelemetryEventParam("RoleInstanceName",
|
||||
self.sysinfo.append(prot.TelemetryEventParam("RoleInstanceName",
|
||||
metadata.roleInstanceId))
|
||||
self.sysInfo.append(prot.TelemetryEventParam("ContainerId",
|
||||
self.sysinfo.append(prot.TelemetryEventParam("ContainerId",
|
||||
metadata.containerId))
|
||||
|
||||
def startEventsLoop(self):
|
||||
eventThread = threading.Thread(target = self.eventsLoop)
|
||||
eventThread.setDaemon(True)
|
||||
eventThread.start()
|
||||
def start(self):
|
||||
event_thread = threading.Thread(target = self.run)
|
||||
event_thread.setDaemon(True)
|
||||
event_thread.start()
|
||||
|
||||
def collectEvent(self, eventFilePath):
|
||||
def collect_event(self, evt_file_name):
|
||||
try:
|
||||
with open(eventFilePath, "rb") as hfile:
|
||||
with open(evt_file_name, "rb") as evt_file:
|
||||
#if fail to open or delete the file, throw exception
|
||||
jsonStr = hfile.read().decode("utf-8",'ignore')
|
||||
os.remove(eventFilePath)
|
||||
return jsonStr
|
||||
json_str = evt_file.read().decode("utf-8",'ignore')
|
||||
os.remove(evt_file_name)
|
||||
return json_str
|
||||
except IOError as e:
|
||||
msg = "Failed to process {0}, {1}".format(eventFilePath, e)
|
||||
msg = "Failed to process {0}, {1}".format(evt_file_name, e)
|
||||
raise EventError(msg)
|
||||
|
||||
def collectAndSendEvents(self):
|
||||
eventList = prot.TelemetryEventList()
|
||||
eventFiles = os.listdir(self.eventDir)
|
||||
for eventFile in eventFiles:
|
||||
if not eventFile.endswith(".tld"):
|
||||
def collect_and_send_events(self):
|
||||
event_list = prot.TelemetryEventList()
|
||||
event_files = os.listdir(self.event_dir)
|
||||
for event_file in event_files:
|
||||
if not event_file.endswith(".tld"):
|
||||
continue
|
||||
eventFilePath = os.path.join(self.eventDir, eventFile)
|
||||
event_file_path = os.path.join(self.event_dir, event_file)
|
||||
try:
|
||||
dataStr = self.collectEvent(eventFilePath)
|
||||
data_str = self.collect_event(event_file_path)
|
||||
except EventError as e:
|
||||
logger.Error("{0}", e)
|
||||
logger.error("{0}", e)
|
||||
continue
|
||||
try:
|
||||
data = json.loads(dataStr)
|
||||
data = json.loads(data_str)
|
||||
except ValueError as e:
|
||||
logger.Verbose(dataStr)
|
||||
logger.Error("Failed to decode json event file{0}", e)
|
||||
logger.verb(data_str)
|
||||
logger.error("Failed to decode json event file{0}", e)
|
||||
continue
|
||||
|
||||
event = prot.TelemetryEvent()
|
||||
prot.set_properties(event, data)
|
||||
event.parameters.extend(self.sysInfo)
|
||||
eventList.events.append(event)
|
||||
if len(eventList.events) == 0:
|
||||
event.parameters.extend(self.sysinfo)
|
||||
event_list.events.append(event)
|
||||
if len(event_list.events) == 0:
|
||||
return
|
||||
|
||||
try:
|
||||
protocol = prot.Factory.getDefaultProtocol()
|
||||
protocol.reportEvent(eventList)
|
||||
protocol = prot.Factory.get_default_protocol()
|
||||
protocol.report_event(event_list)
|
||||
except prot.ProtocolError as e:
|
||||
logger.Error("{0}", e)
|
||||
logger.error("{0}", e)
|
||||
|
||||
def eventsLoop(self):
|
||||
lastHeatbeat = datetime.datetime.min
|
||||
def run(self):
|
||||
last_heartbeat = datetime.datetime.min
|
||||
period = datetime.timedelta(hours = 12)
|
||||
while(True):
|
||||
if (datetime.datetime.now()-lastHeatbeat) > period:
|
||||
lastHeatbeat = datetime.datetime.now()
|
||||
AddExtensionEvent(op=WALAEventOperation.HeartBeat,
|
||||
name="WALA",isSuccess=True)
|
||||
self.collectAndSendEvents()
|
||||
if (datetime.datetime.now()-last_heartbeat) > period:
|
||||
last_heartbeat = datetime.datetime.now()
|
||||
add_event(op=WALAEventOperation.HeartBeat,
|
||||
name="WALA",is_success=True)
|
||||
self.collect_and_send_events()
|
||||
time.sleep(60)
|
||||
|
||||
def SaveEvent(data):
|
||||
eventfolder = os.path.join(OSUtil.GetLibDir(), 'events')
|
||||
if not os.path.exists(eventfolder):
|
||||
os.mkdir(eventfolder)
|
||||
os.chmod(eventfolder,0700)
|
||||
if len(os.listdir(eventfolder)) > 1000:
|
||||
raise EventError("Too many files under: {0}", eventfolder)
|
||||
|
||||
filename = os.path.join(eventfolder, str(int(time.time()*1000000)))
|
||||
def save_event(data):
|
||||
event_dir = os.path.join(OSUTIL.get_lib_dir(), 'events')
|
||||
if not os.path.exists(event_dir):
|
||||
os.mkdir(event_dir)
|
||||
os.chmod(event_dir,0700)
|
||||
if len(os.listdir(event_dir)) > 1000:
|
||||
raise EventError("Too many files under: {0}", event_dir)
|
||||
|
||||
filename = os.path.join(event_dir, str(int(time.time()*1000000)))
|
||||
try:
|
||||
with open(filename+".tmp",'wb+') as hfile:
|
||||
hfile.write(data.encode("utf-8"))
|
||||
@@ -152,38 +152,38 @@ def SaveEvent(data):
|
||||
raise EventError("Failed to write events to file:{0}", e)
|
||||
|
||||
|
||||
def AddExtensionEvent(name, op, isSuccess, duration=0, version="1.0",
|
||||
message="", evtType="", isInternal=False):
|
||||
def add_event(name, op, is_success, duration=0, version="1.0",
|
||||
message="", evt_type="", is_internal=False):
|
||||
event = prot.TelemetryEvent(1, "69B669B9-4AF8-4C50-BDC4-6006FA76E975")
|
||||
event.parameters.append(prot.TelemetryEventParam('Name', name))
|
||||
event.parameters.append(prot.TelemetryEventParam('Version', version))
|
||||
event.parameters.append(prot.TelemetryEventParam('IsInternal', isInternal))
|
||||
event.parameters.append(prot.TelemetryEventParam('Operation', op))
|
||||
event.parameters.append(prot.TelemetryEventParam('OperationSuccess',
|
||||
isSuccess))
|
||||
event.parameters.append(prot.TelemetryEventParam('Message', message))
|
||||
event.parameters.append(prot.TelemetryEventParam('Duration', duration))
|
||||
event.parameters.append(prot.TelemetryEventParam('ExtensionType', evtType))
|
||||
|
||||
event.parameters.append(prot.TelemetryEventParam('Name', name))
|
||||
event.parameters.append(prot.TelemetryEventParam('Version', version))
|
||||
event.parameters.append(prot.TelemetryEventParam('IsInternal', is_internal))
|
||||
event.parameters.append(prot.TelemetryEventParam('Operation', op))
|
||||
event.parameters.append(prot.TelemetryEventParam('OperationSuccess',
|
||||
is_success))
|
||||
event.parameters.append(prot.TelemetryEventParam('Message', message))
|
||||
event.parameters.append(prot.TelemetryEventParam('Duration', duration))
|
||||
event.parameters.append(prot.TelemetryEventParam('ExtensionType', evt_type))
|
||||
|
||||
data = prot.get_properties(event)
|
||||
try:
|
||||
SaveEvent(json.dumps(data))
|
||||
save_event(json.dumps(data))
|
||||
except EventError as e:
|
||||
logger.Error("{0}", e)
|
||||
logger.error("{0}", e)
|
||||
|
||||
def DumpUnhandledError(name):
|
||||
def dump_unhandled_err(name):
|
||||
if hasattr(sys, 'last_type') and hasattr(sys, 'last_value') and \
|
||||
hasattr(sys, 'last_traceback'):
|
||||
last_type = getattr(sys, 'last_type')
|
||||
last_value = getattr(sys, 'last_value')
|
||||
last_traceback = getattr(sys, 'last_traceback')
|
||||
error = traceback.format_exception(last_type, last_value,
|
||||
error = traceback.format_exception(last_type, last_value,
|
||||
last_traceback)
|
||||
message= "".join(error)
|
||||
logger.Error(message)
|
||||
AddExtensionEvent(name, isSuccess=False, message=message,
|
||||
logger.error(message)
|
||||
add_event(name, is_success=False, message=message,
|
||||
op=WALAEventOperation.UnhandledError)
|
||||
|
||||
def EnableUnhandledErrorDump(name):
|
||||
atexit.register(DumpUnhandledError, name)
|
||||
def enable_unhandled_err_dump(name):
|
||||
atexit.register(dump_unhandled_err, name)
|
||||
|
||||
|
||||
@@ -16,46 +16,50 @@
|
||||
#
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
"""
|
||||
Defines all exceptions
|
||||
"""
|
||||
|
||||
"""
|
||||
Base class of agent error.
|
||||
"""
|
||||
class AgentError(Exception):
|
||||
"""
|
||||
Base class of agent error.
|
||||
"""
|
||||
def __init__(self, errno, msg):
|
||||
msg = "({0}){1}".format(errno, msg)
|
||||
super(AgentError, self).__init__(msg)
|
||||
|
||||
"""
|
||||
When configure file is not found or malformed.
|
||||
"""
|
||||
class AgentConfigError(AgentError):
|
||||
"""
|
||||
When configure file is not found or malformed.
|
||||
"""
|
||||
def __init__(self, msg):
|
||||
super(AgentConfigError, self).__init__('000001', msg)
|
||||
|
||||
"""
|
||||
When network is not avaiable.
|
||||
"""
|
||||
class AgentNetworkError(AgentError):
|
||||
"""
|
||||
When network is not avaiable.
|
||||
"""
|
||||
def __init__(self, msg):
|
||||
super(AgentNetworkError, self).__init__('000002', msg)
|
||||
|
||||
"""
|
||||
When failed to execute an extension
|
||||
"""
|
||||
class ExtensionError(AgentError):
|
||||
"""
|
||||
When failed to execute an extension
|
||||
"""
|
||||
def __init__(self, msg):
|
||||
super(ExtensionError, self).__init__('000003', msg)
|
||||
|
||||
"""
|
||||
When provision failed
|
||||
"""
|
||||
class ProvisionError(AgentError):
|
||||
"""
|
||||
When provision failed
|
||||
"""
|
||||
def __init__(self, msg):
|
||||
super(ProvisionError, self).__init__('000004', msg)
|
||||
"""
|
||||
Mount resource disk failed
|
||||
"""
|
||||
|
||||
class ResourceDiskError(AgentError):
|
||||
"""
|
||||
Mount resource disk failed
|
||||
"""
|
||||
def __init__(self, msg):
|
||||
super(ResourceDiskError, self).__init__('000005', msg)
|
||||
|
||||
|
||||
@@ -18,9 +18,11 @@
|
||||
#
|
||||
|
||||
"""
|
||||
Load OSUtil implementation from azurelinuxagent.distro
|
||||
Handler handles different tasks like, provisioning, deprovisioning etc.
|
||||
The handlers could be extended for different distros. The default
|
||||
implementation is under azurelinuxagent.distros.default
|
||||
"""
|
||||
import azurelinuxagent.distro.loader as loader
|
||||
|
||||
Handlers = loader.GetHandlers()
|
||||
HANDLERS = loader.get_handlers()
|
||||
|
||||
|
||||
+45
-53
@@ -17,13 +17,18 @@
|
||||
# Implements parts of RFC 2131, 1541, 1497 and
|
||||
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
|
||||
# http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx
|
||||
"""
|
||||
Log utils
|
||||
"""
|
||||
|
||||
import sys
|
||||
import traceback
|
||||
import azurelinuxagent.utils.textutil as textutil
|
||||
from datetime import datetime
|
||||
|
||||
class Logger(object):
|
||||
"""
|
||||
Logger class
|
||||
"""
|
||||
def __init__(self, logger=None, prefix=None):
|
||||
self.appenders = []
|
||||
if logger is not None:
|
||||
@@ -43,23 +48,24 @@ class Logger(object):
|
||||
self.log(LogLevel.ERROR, msg_format, *args)
|
||||
|
||||
def log(self, level, msg_format, *args):
|
||||
msg_format = textutil.Ascii(msg_format)
|
||||
args = map(lambda x : textutil.Ascii(x), args)
|
||||
if(len(args) > 0):
|
||||
msg_format = textutil.ascii(msg_format)
|
||||
args = map(lambda x: textutil.ascii(x), args)
|
||||
if len(args) > 0:
|
||||
msg = msg_format.format(*args)
|
||||
else:
|
||||
msg = msg_format
|
||||
time = datetime.now().strftime('%Y/%m/%d %H:%M:%S.%f')
|
||||
levelStr = LogLevel.STRINGS[level]
|
||||
level_str = LogLevel.STRINGS[level]
|
||||
if self.prefix is not None:
|
||||
logItem = "{0} {1} {2} {3}".format(time, levelStr, self.prefix, msg)
|
||||
log_item = "{0} {1} {2} {3}".format(time, level_str, self.prefix,
|
||||
msg)
|
||||
else:
|
||||
logItem = "{0} {1} {2}".format(time, levelStr, msg)
|
||||
log_item = "{0} {1} {2}".format(time, level_str, msg)
|
||||
for appender in self.appenders:
|
||||
appender.write(level, logItem)
|
||||
appender.write(level, log_item)
|
||||
|
||||
def addLoggerAppender(self, appenderType, level, path):
|
||||
appender = CreateLoggerAppender(appenderType, level, path)
|
||||
def add_appender(self, appender_type, level, path):
|
||||
appender = _create_logger_appender(appender_type, level, path)
|
||||
self.appenders.append(appender)
|
||||
|
||||
class ConsoleAppender(object):
|
||||
@@ -68,15 +74,15 @@ class ConsoleAppender(object):
|
||||
if level >= LogLevel.INFO:
|
||||
self.level = level
|
||||
self.path = path
|
||||
|
||||
|
||||
def write(self, level, msg):
|
||||
if self.level <= level:
|
||||
try:
|
||||
with open(self.path, "w") as console :
|
||||
console.write(msg.encode('ascii','ignore') + "\n")
|
||||
except IOError as e:
|
||||
with open(self.path, "w") as console:
|
||||
console.write(msg.encode('ascii', 'ignore') + "\n")
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
|
||||
class FileAppender(object):
|
||||
def __init__(self, level, path):
|
||||
self.level = level
|
||||
@@ -86,8 +92,8 @@ class FileAppender(object):
|
||||
if self.level <= level:
|
||||
try:
|
||||
with open(self.path, "a+") as log_file:
|
||||
log_file.write(msg.encode('ascii','ignore') + "\n")
|
||||
except IOError as e:
|
||||
log_file.write(msg.encode('ascii', 'ignore') + "\n")
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
class StdoutAppender(object):
|
||||
@@ -97,13 +103,13 @@ class StdoutAppender(object):
|
||||
def write(self, level, msg):
|
||||
if self.level <= level:
|
||||
try:
|
||||
sys.stdout.write(msg.encode('ascii','ignore') + "\n")
|
||||
except IOError as e:
|
||||
sys.stdout.write(msg.encode('ascii', 'ignore') + "\n")
|
||||
except IOError:
|
||||
pass
|
||||
|
||||
|
||||
#Initialize logger instance
|
||||
DefaultLogger = Logger()
|
||||
default_logger = Logger()
|
||||
|
||||
class LogLevel(object):
|
||||
VERBOSE = 0
|
||||
@@ -118,49 +124,35 @@ class LogLevel(object):
|
||||
]
|
||||
|
||||
class AppenderType(object):
|
||||
FILE=0
|
||||
CONSOLE=1
|
||||
STDOUT=2
|
||||
FILE = 0
|
||||
CONSOLE = 1
|
||||
STDOUT = 2
|
||||
|
||||
def AddLoggerAppender(appenderType, level=LogLevel.INFO, path=None):
|
||||
DefaultLogger.addLoggerAppender(appenderType, level, path)
|
||||
def add_logger_appender(appender_type, level=LogLevel.INFO, path=None):
|
||||
default_logger.add_appender(appender_type, level, path)
|
||||
|
||||
def Verbose(msg_format, *args):
|
||||
DefaultLogger.verbose(msg_format, *args)
|
||||
def verb(msg_format, *args):
|
||||
default_logger.verbose(msg_format, *args)
|
||||
|
||||
def Info(msg_format, *args):
|
||||
DefaultLogger.info(msg_format, *args)
|
||||
def info(msg_format, *args):
|
||||
default_logger.info(msg_format, *args)
|
||||
|
||||
def Warn(msg_format, *args):
|
||||
DefaultLogger.warn(msg_format, *args)
|
||||
def warn(msg_format, *args):
|
||||
default_logger.warn(msg_format, *args)
|
||||
|
||||
def Error(msg_format, *args):
|
||||
DefaultLogger.error(msg_format, *args)
|
||||
def error(msg_format, *args):
|
||||
default_logger.error(msg_format, *args)
|
||||
|
||||
def Log(level, msg_format, *args):
|
||||
DefaultLogger.log(level, msg_format, args)
|
||||
def log(level, msg_format, *args):
|
||||
default_logger.log(level, msg_format, args)
|
||||
|
||||
def CreateLoggerAppender(appenderType, level=LogLevel.INFO, path=None):
|
||||
if appenderType == AppenderType.CONSOLE :
|
||||
def _create_logger_appender(appender_type, level=LogLevel.INFO, path=None):
|
||||
if appender_type == AppenderType.CONSOLE:
|
||||
return ConsoleAppender(level, path)
|
||||
elif appenderType == AppenderType.FILE :
|
||||
elif appender_type == AppenderType.FILE:
|
||||
return FileAppender(level, path)
|
||||
elif appenderType == AppenderType.STDOUT :
|
||||
elif appender_type == AppenderType.STDOUT:
|
||||
return StdoutAppender(level)
|
||||
else:
|
||||
raise ValueError("Unknown appender type")
|
||||
|
||||
def LogError(operation):
|
||||
def Decorator(func):
|
||||
def Wrapper(*args, **kwargs):
|
||||
try:
|
||||
result = func(*args, **kwargs)
|
||||
except Exception, e:
|
||||
Error("Failed to {0} :{1} {2}",
|
||||
operation,
|
||||
e,
|
||||
traceback.format_exc())
|
||||
raise e
|
||||
return result
|
||||
return Wrapper
|
||||
return Decorator
|
||||
|
||||
+21
-21
@@ -21,40 +21,40 @@ import os
|
||||
import re
|
||||
import platform
|
||||
|
||||
def GetDistroInfo():
|
||||
def get_distro():
|
||||
if 'FreeBSD' in platform.system():
|
||||
release = re.sub('\-.*\Z', '', str(platform.release()))
|
||||
osInfo = ['freebsd', release, '', 'freebsd']
|
||||
osinfo = ['freebsd', release, '', 'freebsd']
|
||||
if 'linux_distribution' in dir(platform):
|
||||
osInfo = list(platform.linux_distribution(full_distribution_name=0))
|
||||
fullName = platform.linux_distribution()[0].strip()
|
||||
osInfo.append(fullName)
|
||||
osinfo = list(platform.linux_distribution(full_distribution_name=0))
|
||||
full_name = platform.linux_distribution()[0].strip()
|
||||
osinfo.append(full_name)
|
||||
else:
|
||||
osInfo = platform.dist()
|
||||
osinfo = platform.dist()
|
||||
|
||||
#The platform.py lib has issue with detecting oracle linux distribution.
|
||||
#Merge the following patch provided by oracle as a temparory fix.
|
||||
if os.path.exists("/etc/oracle-release"):
|
||||
osInfo[2]="oracle"
|
||||
osInfo[3]="Oracle Linux"
|
||||
if os.path.exists("/etc/oracle-release"):
|
||||
osinfo[2] = "oracle"
|
||||
osinfo[3] = "Oracle Linux"
|
||||
|
||||
#Remove trailing whitespace and quote in distro name
|
||||
osInfo[0] = osInfo[0].strip('"').strip(' ').lower()
|
||||
return osInfo
|
||||
osinfo[0] = osinfo[0].strip('"').strip(' ').lower()
|
||||
return osinfo
|
||||
|
||||
GuestAgentName = "AzureLinuxAgent"
|
||||
GuestAgentLongName = "Azure Linux Agent"
|
||||
GuestAgentVersion='2.1.0'
|
||||
GuestAgentLongVersion = "{0}-{1}".format(GuestAgentName, GuestAgentVersion)
|
||||
GuestAgentDescription = """\
|
||||
AGENT_NAME = "AzureLinuxAgent"
|
||||
agent_long_name = "Azure Linux Agent"
|
||||
AGENT_VERSION = '2.1.0'
|
||||
agent_long_version = "{0}-{1}".format(AGENT_NAME, AGENT_VERSION)
|
||||
agent_description = """\
|
||||
The Azure Linux Agent supports the provisioning and running of Linux
|
||||
VMs in the Azure cloud. This package should be installed on Linux disk
|
||||
images that are built to run in the Azure environment.
|
||||
"""
|
||||
|
||||
__DistroInfo = GetDistroInfo()
|
||||
DistroName = __DistroInfo[0]
|
||||
DistroVersion = __DistroInfo[1]
|
||||
DistroCodeName = __DistroInfo[2]
|
||||
DistroFullName = __DistroInfo[3]
|
||||
__distro__ = get_distro()
|
||||
DISTRO_NAME = __distro__[0]
|
||||
DISTRO_VERSION = __distro__[1]
|
||||
DISTRO_CODE_NAME = __distro__[2]
|
||||
DISTRO_FULL_NAME = __distro__[3]
|
||||
|
||||
|
||||
@@ -19,5 +19,5 @@
|
||||
|
||||
from azurelinuxagent.protocol.common import *
|
||||
from azurelinuxagent.protocol.protocolFactory import Factory, \
|
||||
DetectDefaultProtocol
|
||||
detect_default_protocol
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ import re
|
||||
import json
|
||||
import xml.dom.minidom
|
||||
import azurelinuxagent.logger as logger
|
||||
from azurelinuxagent.utils.textutil import GetNodeTextData
|
||||
from azurelinuxagent.utils.textutil import get_node_text
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
|
||||
class ProtocolError(Exception):
|
||||
@@ -31,12 +31,12 @@ class ProtocolError(Exception):
|
||||
class ProtocolNotFound(Exception):
|
||||
pass
|
||||
|
||||
def validata_param(name, val, expectedType):
|
||||
def validata_param(name, val, expected_type):
|
||||
if val is None:
|
||||
raise ProtocolError("Param {0} is None".format(name))
|
||||
if not isinstance(val, expectedType):
|
||||
if not isinstance(val, expected_type):
|
||||
raise ProtocolError("Param {0} type should be {1}".format(name,
|
||||
expectedType))
|
||||
expected_type))
|
||||
|
||||
def set_properties(obj, data):
|
||||
validata_param("obj", obj, DataContract)
|
||||
@@ -45,20 +45,20 @@ def set_properties(obj, data):
|
||||
props = vars(obj)
|
||||
for name, val in props.items():
|
||||
try:
|
||||
newVal = data[name]
|
||||
new_val = data[name]
|
||||
except KeyError:
|
||||
continue
|
||||
|
||||
if isinstance(newVal, dict):
|
||||
set_properties(val, newVal)
|
||||
elif isinstance(newVal, list):
|
||||
if isinstance(new_val, dict):
|
||||
set_properties(val, new_val)
|
||||
elif isinstance(new_val, list):
|
||||
validata_param("list", val, DataContractList)
|
||||
for dataItem in newVal:
|
||||
item = val.itemType()
|
||||
set_properties(item, dataItem)
|
||||
for data_item in new_val:
|
||||
item = val.itemType()
|
||||
set_properties(item, data_item)
|
||||
val.append(item)
|
||||
else:
|
||||
setattr(obj, name, newVal)
|
||||
setattr(obj, name, new_val)
|
||||
|
||||
def get_properties(obj):
|
||||
validata_param("obj", obj, DataContract)
|
||||
@@ -86,7 +86,7 @@ class DataContractList(list):
|
||||
def __init__(self, itemType):
|
||||
self.itemType = itemType
|
||||
|
||||
class VmInfo(DataContract):
|
||||
class VMInfo(DataContract):
|
||||
def __init__(self, subscriptionId=None, vmName=None):
|
||||
self.subscriptionId = subscriptionId
|
||||
self.vmName = vmName
|
||||
@@ -179,14 +179,14 @@ class ExtensionSubStatus(DataContract):
|
||||
|
||||
class ExtensionStatus(DataContract):
|
||||
def __init__(self, name=None, configurationAppliedTime=None, operation=None,
|
||||
status=None, code=None, message=None, sequenceNumber=None):
|
||||
status=None, code=None, message=None, seq_no=None):
|
||||
self.name = name
|
||||
self.configurationAppliedTime = configurationAppliedTime
|
||||
self.operation = operation
|
||||
self.status = status
|
||||
self.code = code
|
||||
self.message = message
|
||||
self.sequenceNumber = sequenceNumber
|
||||
self.sequenceNumber = seq_no
|
||||
self.substatusList = DataContractList(ExtensionSubStatus)
|
||||
|
||||
class ExtensionHandlerStatus(DataContract):
|
||||
@@ -223,27 +223,27 @@ class Protocol(DataContract):
|
||||
def initialize(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def getVmInfo(self):
|
||||
def get_vminfo(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def getCerts(self):
|
||||
def get_certs(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def getExtensions(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def getExtensionPackages(self, extension):
|
||||
raise NotImplementedError()
|
||||
|
||||
def getInstanceMetadata(self):
|
||||
def get_extensions(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def reportProvisionStatus(self, status):
|
||||
def get_extension_pkgs(self, extension):
|
||||
raise NotImplementedError()
|
||||
|
||||
def reportStatus(self, status):
|
||||
def get_instance_metadata(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def reportEvent(self, event):
|
||||
def report_provision_status(self, status):
|
||||
raise NotImplementedError()
|
||||
|
||||
def report_status(self, status):
|
||||
raise NotImplementedError()
|
||||
|
||||
def report_event(self, event):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -19,161 +19,161 @@
|
||||
|
||||
from azurelinuxagent.protocol.common import *
|
||||
|
||||
from azurelinuxagent.utils.osutil import OSUtil, OSUtilError
|
||||
from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError
|
||||
|
||||
def GetOvfEnv():
|
||||
ovfFilePath = os.path.join(OSUtil.GetLibDir(), OvfFileName)
|
||||
if os.path.isfile(ovfFilePath):
|
||||
xmlText = fileutil.GetFileContents(ovfFilePath)
|
||||
return OvfEnv(xmlText)
|
||||
def get_ovf_env():
|
||||
ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME)
|
||||
if os.path.isfile(ovf_file_path):
|
||||
xml_text = fileutil.read_file(ovf_file_path)
|
||||
return OvfEnv(xml_text)
|
||||
else:
|
||||
raise ProtocolError("ovf-env.xml is missing.")
|
||||
|
||||
def CopyOvfEnv():
|
||||
def copy_ovf_env():
|
||||
"""
|
||||
Copy ovf env file from dvd to hard disk.
|
||||
Copy ovf env file from dvd to hard disk.
|
||||
Remove password before save it to the disk
|
||||
"""
|
||||
try:
|
||||
OSUtil.MountDvd()
|
||||
ovfFile = OSUtil.GetOvfEnvPathOnDvd()
|
||||
OSUTIL.mount_dvd()
|
||||
ovf_file_path_on_dvd = OSUTIL.get_ovf_env_file_path_on_dvd()
|
||||
|
||||
ovfxml = fileutil.GetFileContents(ovfFile, removeBom=True)
|
||||
ovfxml = fileutil.read_file(ovf_file_path_on_dvd, remove_bom=True)
|
||||
ovfenv = OvfEnv(ovfxml)
|
||||
ovfxml = re.sub("<UserPassword>.*?<", "<UserPassword>*<", ovfxml)
|
||||
ovfFilePath = os.path.join(OSUtil.GetLibDir(), OvfFileName)
|
||||
fileutil.SetFileContents(ovfFilePath, ovfxml)
|
||||
OSUtil.UmountDvd()
|
||||
ovf_file_path = os.path.join(OSUTIL.get_lib_dir(), OVF_FILE_NAME)
|
||||
fileutil.write_file(ovf_file_path, ovfxml)
|
||||
OSUTIL.umount_dvd()
|
||||
except IOError as e:
|
||||
raise ProtocolError(str(e))
|
||||
except OSUtilError as e:
|
||||
raise ProtocolError(str(e))
|
||||
return ovfenv
|
||||
|
||||
OvfFileName="ovf-env.xml"
|
||||
OVF_FILE_NAME="ovf-env.xml"
|
||||
class OvfEnv(object):
|
||||
"""
|
||||
Read, and process provisioning info from provisioning file OvfEnv.xml
|
||||
"""
|
||||
def __init__(self, xmlText):
|
||||
if xmlText is None:
|
||||
def __init__(self, xml_text):
|
||||
if xml_text is None:
|
||||
raise ValueError("ovf-env is None")
|
||||
logger.Verbose("Load ovf-env.xml")
|
||||
self.parse(xmlText)
|
||||
logger.verb("Load ovf-env.xml")
|
||||
self.parse(xml_text)
|
||||
|
||||
def reinitialize(self):
|
||||
"""
|
||||
Reset members.
|
||||
"""
|
||||
self.WaNs = "http://schemas.microsoft.com/windowsazure"
|
||||
self.OvfNs = "http://schemas.dmtf.org/ovf/environment/1"
|
||||
self.MajorVersion = 1
|
||||
self.MinorVersion = 0
|
||||
self.ComputerName = None
|
||||
self.UserName = None
|
||||
self.UserPassword = None
|
||||
self.CustomData = None
|
||||
self.DisableSshPasswordAuthentication = True
|
||||
self.SshPublicKeys = []
|
||||
self.SshKeyPairs = []
|
||||
self.wa_ns = "http://schemas.microsoft.com/windowsazure"
|
||||
self.ovf_ns = "http://schemas.dmtf.org/ovf/environment/1"
|
||||
self.major_version = 1
|
||||
self.minor_version = 0
|
||||
self.compute_name = None
|
||||
self.user_name = None
|
||||
self.user_password = None
|
||||
self.customdata = None
|
||||
self.disable_ssh_password_auth = True
|
||||
self.ssh_pubkeys = []
|
||||
self.ssh_keypairs = []
|
||||
|
||||
def getMajorVersion(self):
|
||||
return self.MajorVersion
|
||||
def get_major_version(self):
|
||||
return self.major_version
|
||||
|
||||
def getMinorVersion(self):
|
||||
return self.MinorVersion
|
||||
def get_minor_version(self):
|
||||
return self.minor_version
|
||||
|
||||
def getComputerName(self):
|
||||
return self.ComputerName
|
||||
def get_computer_name(self):
|
||||
return self.compute_name
|
||||
|
||||
def getUserName(self):
|
||||
return self.UserName
|
||||
def get_username(self):
|
||||
return self.user_name
|
||||
|
||||
def getUserPassword(self):
|
||||
return self.UserPassword
|
||||
def get_user_password(self):
|
||||
return self.user_password
|
||||
|
||||
def clearUserPassword(self):
|
||||
self.UserPassword = None
|
||||
def clear_user_password(self):
|
||||
self.user_password = None
|
||||
|
||||
def getCustomData(self):
|
||||
return self.CustomData
|
||||
def get_customdata(self):
|
||||
return self.customdata
|
||||
|
||||
def getDisableSshPasswordAuthentication(self):
|
||||
return self.DisableSshPasswordAuthentication
|
||||
def get_disable_ssh_password_auth(self):
|
||||
return self.disable_ssh_password_auth
|
||||
|
||||
def getSshPublicKeys(self):
|
||||
return self.SshPublicKeys
|
||||
def get_ssh_pubkeys(self):
|
||||
return self.ssh_pubkeys
|
||||
|
||||
def getSshKeyPairs(self):
|
||||
return self.SshKeyPairs
|
||||
def get_ssh_keypairs(self):
|
||||
return self.ssh_keypairs
|
||||
|
||||
def parse(self, xmlText):
|
||||
def parse(self, xml_text):
|
||||
"""
|
||||
Parse xml tree, retreiving user and ssh key information.
|
||||
Return self.
|
||||
"""
|
||||
self.reinitialize()
|
||||
dom = xml.dom.minidom.parseString(xmlText)
|
||||
if len(dom.getElementsByTagNameNS(self.OvfNs, "Environment")) != 1:
|
||||
logger.Error("Unable to parse OVF XML.")
|
||||
dom = xml.dom.minidom.parseString(xml_text)
|
||||
if len(dom.getElementsByTagNameNS(self.ovf_ns, "Environment")) != 1:
|
||||
logger.error("Unable to parse OVF XML.")
|
||||
section = None
|
||||
newer = False
|
||||
for p in dom.getElementsByTagNameNS(self.WaNs, "ProvisioningSection"):
|
||||
for p in dom.getElementsByTagNameNS(self.wa_ns, "ProvisioningSection"):
|
||||
for n in p.childNodes:
|
||||
if n.localName == "Version":
|
||||
verparts = GetNodeTextData(n).split('.')
|
||||
verparts = get_node_text(n).split('.')
|
||||
major = int(verparts[0])
|
||||
minor = int(verparts[1])
|
||||
if major > self.MajorVersion:
|
||||
if major > self.major_version:
|
||||
newer = True
|
||||
if major != self.MajorVersion:
|
||||
if major != self.major_version:
|
||||
break
|
||||
if minor > self.MinorVersion:
|
||||
if minor > self.minor_version:
|
||||
newer = True
|
||||
section = p
|
||||
if newer == True:
|
||||
logger.Warn("Newer provisioning configuration detected. "
|
||||
logger.warn("Newer provisioning configuration detected. "
|
||||
"Please consider updating waagent.")
|
||||
if section == None:
|
||||
logger.Error("Could not find ProvisioningSection with "
|
||||
"major version={0}", self.MajorVersion)
|
||||
logger.error("Could not find ProvisioningSection with "
|
||||
"major version={0}", self.major_version)
|
||||
return None
|
||||
self.ComputerName = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "HostName")[0])
|
||||
self.UserName = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "UserName")[0])
|
||||
self.compute_name = get_node_text(section.getElementsByTagNameNS(self.wa_ns, "HostName")[0])
|
||||
self.user_name = get_node_text(section.getElementsByTagNameNS(self.wa_ns, "UserName")[0])
|
||||
try:
|
||||
self.UserPassword = GetNodeTextData(section.getElementsByTagNameNS(self.WaNs, "UserPassword")[0])
|
||||
self.user_password = get_node_text(section.getElementsByTagNameNS(self.wa_ns, "UserPassword")[0])
|
||||
except:
|
||||
pass
|
||||
CDSection=None
|
||||
CDSection=section.getElementsByTagNameNS(self.WaNs, "CustomData")
|
||||
if len(CDSection) > 0 :
|
||||
self.CustomData=GetNodeTextData(CDSection[0])
|
||||
disableSshPass = section.getElementsByTagNameNS(self.WaNs, "DisableSshPasswordAuthentication")
|
||||
if len(disableSshPass) != 0:
|
||||
self.DisableSshPasswordAuthentication = (GetNodeTextData(disableSshPass[0]).lower() == "true")
|
||||
for pkey in section.getElementsByTagNameNS(self.WaNs, "PublicKey"):
|
||||
logger.Verbose(repr(pkey))
|
||||
cd_section=None
|
||||
cd_section=section.getElementsByTagNameNS(self.wa_ns, "CustomData")
|
||||
if len(cd_section) > 0 :
|
||||
self.customdata=get_node_text(cd_section[0])
|
||||
disable_ssh_password_auth = section.getElementsByTagNameNS(self.wa_ns, "DisableSshPasswordAuthentication")
|
||||
if len(disable_ssh_password_auth) != 0:
|
||||
self.disable_ssh_password_auth = (get_node_text(disable_ssh_password_auth[0]).lower() == "true")
|
||||
for pkey in section.getElementsByTagNameNS(self.wa_ns, "PublicKey"):
|
||||
logger.verb(repr(pkey))
|
||||
fp = None
|
||||
path = None
|
||||
for c in pkey.childNodes:
|
||||
if c.localName == "Fingerprint":
|
||||
fp = GetNodeTextData(c).upper()
|
||||
logger.Verbose(fp)
|
||||
fp = get_node_text(c).upper()
|
||||
logger.verb(fp)
|
||||
if c.localName == "Path":
|
||||
path = GetNodeTextData(c)
|
||||
logger.Verbose(path)
|
||||
self.SshPublicKeys += [[fp, path]]
|
||||
for keyp in section.getElementsByTagNameNS(self.WaNs, "KeyPair"):
|
||||
path = get_node_text(c)
|
||||
logger.verb(path)
|
||||
self.ssh_pubkeys += [[fp, path]]
|
||||
for keyp in section.getElementsByTagNameNS(self.wa_ns, "KeyPair"):
|
||||
fp = None
|
||||
path = None
|
||||
logger.Verbose(repr(keyp))
|
||||
logger.verb(repr(keyp))
|
||||
for c in keyp.childNodes:
|
||||
if c.localName == "Fingerprint":
|
||||
fp = GetNodeTextData(c).upper()
|
||||
logger.Verbose(fp)
|
||||
fp = get_node_text(c).upper()
|
||||
logger.verb(fp)
|
||||
if c.localName == "Path":
|
||||
path = GetNodeTextData(c)
|
||||
logger.Verbose(path)
|
||||
self.SshKeyPairs += [[fp, path]]
|
||||
path = get_node_text(c)
|
||||
logger.verb(path)
|
||||
self.ssh_keypairs += [[fp, path]]
|
||||
return self
|
||||
|
||||
|
||||
@@ -21,19 +21,19 @@ import traceback
|
||||
import threading
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
from azurelinuxagent.protocol.common import *
|
||||
from azurelinuxagent.protocol.v1 import ProtocolV1
|
||||
from azurelinuxagent.protocol.v2 import ProtocolV2
|
||||
from azurelinuxagent.protocol.v1 import WIRE_PROTOCOL
|
||||
from azurelinuxagent.protocol.v2 import MetadataProtocol
|
||||
|
||||
WireServerAddrFile = "WireServer"
|
||||
WireProtocol = "WireProtocol"
|
||||
MetaDataProtocol = "MetaDataProtocol"
|
||||
WIRE_SERVER_ADDR_FILE_NAME = "WireServer"
|
||||
WIRE_PROTOCOL = "WireProtocol"
|
||||
METADATA_PROTOCOL = "MetaDataProtocol"
|
||||
|
||||
def GetWireProtocolEndpoint():
|
||||
path = os.path.join(OSUtil.GetLibDir(), WireServerAddrFile)
|
||||
def get_wire_protocol_endpoint():
|
||||
path = os.path.join(OSUTIL.get_lib_dir(), WIRE_SERVER_ADDR_FILE_NAME)
|
||||
try:
|
||||
endpoint = fileutil.GetFileContents(path)
|
||||
endpoint = fileutil.read_file(path)
|
||||
except IOError as e:
|
||||
raise ProtocolNotFound("Wire server endpoint not found: {0}".format(e))
|
||||
|
||||
@@ -42,84 +42,84 @@ def GetWireProtocolEndpoint():
|
||||
|
||||
return endpoint
|
||||
|
||||
def DetectV1():
|
||||
endpoint = GetWireProtocolEndpoint()
|
||||
def detect_wire_protocol():
|
||||
endpoint = get_wire_protocol_endpoint()
|
||||
|
||||
OSUtil.GenerateTransportCert()
|
||||
protocol = ProtocolV1(endpoint)
|
||||
protocol.initialize()
|
||||
|
||||
logger.Info("Protocol V1 found.")
|
||||
path = os.path.join(OSUtil.GetLibDir(), WireProtocol)
|
||||
|
||||
fileutil.SetFileContents(path, "")
|
||||
return protocol
|
||||
|
||||
def DetectV2():
|
||||
protocol = ProtocolV2()
|
||||
OSUTIL.gen_transport_cert()
|
||||
protocol = WIRE_PROTOCOL(endpoint)
|
||||
protocol.initialize()
|
||||
|
||||
logger.Info("Protocol V2 found.")
|
||||
path = os.path.join(OSUtil.GetLibDir(), MetaDataProtocol)
|
||||
fileutil.SetFileContents(path, "")
|
||||
logger.info("Protocol V1 found.")
|
||||
path = os.path.join(OSUTIL.get_lib_dir(), WIRE_PROTOCOL)
|
||||
|
||||
fileutil.write_file(path, "")
|
||||
return protocol
|
||||
|
||||
def DetectAvailableProtocols(probeFuncs=[DetectV1, DetectV2]):
|
||||
availableProtocols = []
|
||||
for probeFunc in probeFuncs:
|
||||
def detect_metadata_protocol():
|
||||
protocol = MetadataProtocol()
|
||||
protocol.initialize()
|
||||
|
||||
logger.info("Protocol V2 found.")
|
||||
path = os.path.join(OSUTIL.get_lib_dir(), METADATA_PROTOCOL)
|
||||
fileutil.write_file(path, "")
|
||||
return protocol
|
||||
|
||||
def detect_available_protocols(prob_funcs=[detect_wire_protocol, detect_metadata_protocol]):
|
||||
available_protocols = []
|
||||
for probe_func in prob_funcs:
|
||||
try:
|
||||
protocol = probeFunc()
|
||||
availableProtocols.append(protocol)
|
||||
protocol = probe_func()
|
||||
available_protocols.append(protocol)
|
||||
except ProtocolNotFound as e:
|
||||
logger.Info(str(e))
|
||||
return availableProtocols
|
||||
logger.info(str(e))
|
||||
return available_protocols
|
||||
|
||||
def DetectDefaultProtocol():
|
||||
logger.Info("Detect default protocol.")
|
||||
availableProtocols = DetectAvailableProtocols()
|
||||
return ChooseDefaultProtocol(availableProtocols)
|
||||
def detect_default_protocol():
|
||||
logger.info("Detect default protocol.")
|
||||
available_protocols = detect_available_protocols()
|
||||
return choose_default_protocol(available_protocols)
|
||||
|
||||
def ChooseDefaultProtocol(availableProtocols):
|
||||
if len(availableProtocols) > 0:
|
||||
return availableProtocols[0]
|
||||
def choose_default_protocol(protocols):
|
||||
if len(protocols) > 0:
|
||||
return protocols[0]
|
||||
else:
|
||||
raise ProtocolNotFound("No available protocol detected.")
|
||||
|
||||
def GetV1():
|
||||
path = os.path.join(OSUtil.GetLibDir(), WireProtocol)
|
||||
def get_wire_protocol():
|
||||
path = os.path.join(OSUTIL.get_lib_dir(), WIRE_PROTOCOL)
|
||||
if not os.path.isfile(path):
|
||||
raise ProtocolNotFound("Protocol V1 not found")
|
||||
|
||||
endpoint = GetWireProtocolEndpoint()
|
||||
return ProtocolV1(endpoint)
|
||||
|
||||
def GetV2():
|
||||
path = os.path.join(OSUtil.GetLibDir(), MetaDataProtocol)
|
||||
endpoint = get_wire_protocol_endpoint()
|
||||
return WIRE_PROTOCOL(endpoint)
|
||||
|
||||
def get_metadata_protocol():
|
||||
path = os.path.join(OSUTIL.get_lib_dir(), METADATA_PROTOCOL)
|
||||
if not os.path.isfile(path):
|
||||
raise ProtocolNotFound("Protocol V2 not found")
|
||||
return ProtocolV2()
|
||||
return MetadataProtocol()
|
||||
|
||||
def GetAvailableProtocols(getters=[GetV1, GetV2]):
|
||||
availableProtocols = []
|
||||
def get_available_protocols(getters=[get_wire_protocol, get_metadata_protocol]):
|
||||
available_protocols = []
|
||||
for getter in getters:
|
||||
try:
|
||||
protocol = getter()
|
||||
availableProtocols.append(protocol)
|
||||
available_protocols.append(protocol)
|
||||
except ProtocolNotFound as e:
|
||||
logger.Info(str(e))
|
||||
return availableProtocols
|
||||
logger.info(str(e))
|
||||
return available_protocols
|
||||
|
||||
class ProtocolFactory(object):
|
||||
def __init__(self):
|
||||
self._protocol = None
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def getDefaultProtocol(self):
|
||||
def get_default_protocol(self):
|
||||
if self._protocol is None:
|
||||
self._lock.acquire()
|
||||
if self._protocol is None:
|
||||
availableProtocols = GetAvailableProtocols()
|
||||
self._protocol = ChooseDefaultProtocol(availableProtocols)
|
||||
available_protocols = get_available_protocols()
|
||||
self._protocol = choose_default_protocol(available_protocols)
|
||||
self._lock.release()
|
||||
|
||||
return self._protocol
|
||||
|
||||
+500
-545
File diff suppressed because it is too large
Load Diff
@@ -21,34 +21,32 @@ import json
|
||||
import azurelinuxagent.utils.restutil as restutil
|
||||
from azurelinuxagent.protocol.common import *
|
||||
|
||||
DefaultEndpoint='169.254.169.254'
|
||||
DefaultApiVersion='2015-01-01'
|
||||
BaseUri = "https://{0}/Microsoft.Computer/{1}?$api-version={{{2}}}{3}"
|
||||
ENDPOINT='169.254.169.254'
|
||||
APIVERSION='2015-01-01'
|
||||
BASE_URI = "https://{0}/Microsoft.Computer/{1}?$api-version={{{2}}}{3}"
|
||||
|
||||
class ProtocolV2(Protocol):
|
||||
class MetadataProtocol(Protocol):
|
||||
|
||||
def __init__(self, apiVersion=DefaultApiVersion, endpoint=DefaultEndpoint):
|
||||
self.apiVersion = apiVersion
|
||||
def __init__(self, apiversion=APIVERSION, endpoint=ENDPOINT):
|
||||
self.apiversion = apiversion
|
||||
self.endpoint = endpoint
|
||||
self.identityUri = BaseUri.format(self.endpoint, "identity",
|
||||
self.apiVersion, "&expand=*")
|
||||
self.certUri = BaseUri.format(self.endpoint, "certificates",
|
||||
self.apiVersion, "&expand=*")
|
||||
self.certUri = BaseUri.format(self.endpoint, "certificates",
|
||||
self.apiVersion, "&expand=*")
|
||||
self.extUri = BaseUri.format(self.endpoint, "extensionHandlers",
|
||||
self.apiVersion, "&expand=*")
|
||||
self.provisionStatusUri = BaseUri.format(self.endpoint,
|
||||
"provisionStatus",
|
||||
self.apiVersion, "")
|
||||
self.statusUri = BaseUri.format(self.endpoint, "status",
|
||||
self.apiVersion, "")
|
||||
self.eventUri = BaseUri.format(self.endpoint, "status/telemetry",
|
||||
self.apiVersion, "")
|
||||
self.identity_uri = __base__uri.format(self.endpoint, "identity",
|
||||
self.apiversion, "&expand=*")
|
||||
self.cert_uri = __base__uri.format(self.endpoint, "certificates",
|
||||
self.apiversion, "&expand=*")
|
||||
self.ext_uri = __base__uri.format(self.endpoint, "extensionHandlers",
|
||||
self.apiversion, "&expand=*")
|
||||
self.provision_status_uri = __base__uri.format(self.endpoint,
|
||||
"provisionStatus",
|
||||
self.apiversion, "")
|
||||
self.status_uri = __base__uri.format(self.endpoint, "status",
|
||||
self.apiversion, "")
|
||||
self.event_uri = __base__uri.format(self.endpoint, "status/telemetry",
|
||||
self.apiversion, "")
|
||||
|
||||
def _getData(self, dataType, url, headers=None):
|
||||
def _get_data(self, data_type, url, headers=None):
|
||||
try:
|
||||
resp = restutil.HttpGet(url, headers)
|
||||
resp = restutil.http_get(url, headers)
|
||||
except restutil.HttpError as e:
|
||||
raise ProtocolError(str(e))
|
||||
|
||||
@@ -58,23 +56,23 @@ class ProtocolV2(Protocol):
|
||||
data = json.loads(resp.read())
|
||||
except ValueError as e:
|
||||
raise ProtocolError(str(e))
|
||||
obj = dataType()
|
||||
obj = data_type()
|
||||
set_properties(obj, data)
|
||||
return obj
|
||||
|
||||
def _putData(self, url, obj, headers=None):
|
||||
def _put_data(self, url, obj, headers=None):
|
||||
data = get_properties(obj)
|
||||
try:
|
||||
resp = restutil.HttpPut(url, json.dumps(data))
|
||||
resp = restutil.http_put(url, json.dumps(data))
|
||||
except restutil.HttpError as e:
|
||||
raise ProtocolError(str(e))
|
||||
if resp.status != httplib.OK:
|
||||
raise ProtocolError("{0} - PUT: {1}".format(resp.status, url))
|
||||
|
||||
def _postData(self, url, obj, headers=None):
|
||||
def _post_data(self, url, obj, headers=None):
|
||||
data = get_properties(obj)
|
||||
try:
|
||||
resp = restutil.HttpPost(url, json.dumps(data))
|
||||
resp = restutil.http_post(url, json.dumps(data))
|
||||
except restutil.HttpError as e:
|
||||
raise ProtocolError(str(e))
|
||||
if resp.status != httplib.CREATED:
|
||||
@@ -82,27 +80,27 @@ class ProtocolV2(Protocol):
|
||||
|
||||
def initialize(self):
|
||||
pass
|
||||
|
||||
def getVmInfo(self):
|
||||
return self._getData(VmInfo, self.identityUri)
|
||||
|
||||
def getCerts(self):
|
||||
certs = self._getData(CertList, self.certUri)
|
||||
def get_vminfo(self):
|
||||
return self._get_data(VMInfo, self.identity_uri)
|
||||
|
||||
def get_certs(self):
|
||||
certs = self._get_data(CertList, self.cert_uri)
|
||||
#TODO download pfx and convert to pem
|
||||
return certs
|
||||
|
||||
def getExtensions(self):
|
||||
return self._getData(ExtensionList, self.extUri)
|
||||
def get_extensions(self):
|
||||
return self._get_data(ExtensionList, self.ext_uri)
|
||||
|
||||
def reportProvisionStatus(self, status):
|
||||
def report_provision_status(self, status):
|
||||
validata_param('status', status, ProvisionStatus)
|
||||
self._putData(self.provisionStatusUri, status)
|
||||
|
||||
def reportStatus(self, status):
|
||||
validata_param('status', status, VMStatus)
|
||||
self._putData(self.statusUri, status)
|
||||
|
||||
def reportEvent(self, events):
|
||||
validata_param('events', events, TelemetryEventList)
|
||||
self._postData(self.eventUri, events)
|
||||
self._put_data(self.provision_status_uri, status)
|
||||
|
||||
def report_status(self, status):
|
||||
validata_param('status', status, VMStatus)
|
||||
self._put_data(self.status_uri, status)
|
||||
|
||||
def report_event(self, events):
|
||||
validata_param('events', events, TelemetryEventList)
|
||||
self._post_data(self.event_uri, events)
|
||||
|
||||
|
||||
@@ -17,163 +17,160 @@
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
|
||||
import platform
|
||||
"""
|
||||
File operation util functions
|
||||
"""
|
||||
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import pwd
|
||||
import tempfile
|
||||
import subprocess
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.utils.textutil as textutil
|
||||
|
||||
"""
|
||||
File operation util functions
|
||||
"""
|
||||
def GetFileContents(filepath, asbin=False, removeBom=False):
|
||||
def read_file(filepath, asbin=False, remove_bom=False):
|
||||
"""
|
||||
Read and return contents of 'filepath'.
|
||||
"""
|
||||
mode='r'
|
||||
mode = 'r'
|
||||
if asbin:
|
||||
mode+='b'
|
||||
with open(filepath, mode) as F :
|
||||
c=F.read()
|
||||
if (not asbin) and removeBom:
|
||||
c = textutil.RemoveBom(c)
|
||||
return c
|
||||
mode += 'b'
|
||||
with open(filepath, mode) as in_file:
|
||||
contents = in_file.read()
|
||||
if (not asbin) and remove_bom:
|
||||
contents = textutil.remove_bom(contents)
|
||||
return contents
|
||||
|
||||
def SetFileContents(filepath, contents):
|
||||
def write_file(filepath, contents):
|
||||
"""
|
||||
Write 'contents' to 'filepath'.
|
||||
"""
|
||||
#if type(contents) == str :
|
||||
#contents=contents.encode('latin-1', 'ignore')
|
||||
with open(filepath, "wb+") as F :
|
||||
F.write(contents)
|
||||
with open(filepath, "wb") as out_file:
|
||||
out_file.write(contents)
|
||||
|
||||
def AppendFileContents(filepath, contents):
|
||||
def append_file(filepath, contents):
|
||||
"""
|
||||
Append 'contents' to 'filepath'.
|
||||
"""
|
||||
#if type(contents) == str :
|
||||
#contents=contents.encode('latin-1')
|
||||
with open(filepath, "a+") as F :
|
||||
F.write(contents)
|
||||
with open(filepath, "a+") as out_file:
|
||||
out_file.write(contents)
|
||||
|
||||
def ReplaceFileContentsAtomic(filepath, contents):
|
||||
def replace_file(filepath, contents):
|
||||
"""
|
||||
Write 'contents' to 'filepath' by creating a temp file, and replacing original.
|
||||
Write 'contents' to 'filepath' by creating a temp file,
|
||||
and replacing original.
|
||||
"""
|
||||
handle, temp = tempfile.mkstemp(dir = os.path.dirname(filepath))
|
||||
#if type(contents) == str :
|
||||
handle, temp = tempfile.mkstemp(dir=os.path.dirname(filepath))
|
||||
#if type(contents) == str:
|
||||
#contents=contents.encode('latin-1')
|
||||
try:
|
||||
os.write(handle, contents)
|
||||
except IOError, e:
|
||||
logger.Error('Write to file {0}, Exception is {1}', filepath, e)
|
||||
except IOError as err:
|
||||
logger.error('Write to file {0}, Exception is {1}', filepath, err)
|
||||
return 1
|
||||
finally:
|
||||
os.close(handle)
|
||||
|
||||
try:
|
||||
os.rename(temp, filepath)
|
||||
except IOError, e:
|
||||
logger.Info('Rename {0} to {1}, Exception is {2}',temp, filepath, e)
|
||||
logger.Info('Remove original file and retry')
|
||||
except IOError as err:
|
||||
logger.info('Rename {0} to {1}, Exception is {2}', temp, filepath, err)
|
||||
logger.info('Remove original file and retry')
|
||||
try:
|
||||
os.remove(filepath)
|
||||
except IOError, e:
|
||||
logger.Error('Remove {0}, Exception is {1}',temp, filepath, e)
|
||||
except IOError as err:
|
||||
logger.error('Remove {0}, Exception is {1}', temp, filepath, err)
|
||||
|
||||
try:
|
||||
os.rename(temp, filepath)
|
||||
except IOError, e:
|
||||
logger.Error('Rename {0} to {1}, Exception is {2}',temp, filepath, e)
|
||||
except IOError, err:
|
||||
logger.error('Rename {0} to {1}, Exception is {2}', temp, filepath,
|
||||
err)
|
||||
return 1
|
||||
return 0
|
||||
|
||||
def GetLastPathElement(path):
|
||||
def base_name(path):
|
||||
head, tail = os.path.split(path)
|
||||
return tail
|
||||
|
||||
def GetLineStartingWith(prefix, filepath):
|
||||
def get_line_startingwith(prefix, filepath):
|
||||
"""
|
||||
Return line from 'filepath' if the line startswith 'prefix'
|
||||
"""
|
||||
for line in GetFileContents(filepath).split('\n'):
|
||||
for line in read_file(filepath).split('\n'):
|
||||
if line.startswith(prefix):
|
||||
return line
|
||||
return None
|
||||
|
||||
#End File operation util functions
|
||||
|
||||
def CreateDir(dirpath, mode=None, owner=None):
|
||||
def mkdir(dirpath, mode=None, owner=None):
|
||||
if not os.path.isdir(dirpath):
|
||||
os.makedirs(dirpath)
|
||||
if mode is not None:
|
||||
ChangeMod(dirpath, mode)
|
||||
chmod(dirpath, mode)
|
||||
if owner is not None:
|
||||
ChangeOwner(dirpath, owner)
|
||||
chowner(dirpath, owner)
|
||||
|
||||
def ChangeOwner(path, owner):
|
||||
ownerInfo = pwd.getpwnam(owner)
|
||||
os.chown(path, ownerInfo[2], ownerInfo[3])
|
||||
def chowner(path, owner):
|
||||
owner_info = pwd.getpwnam(owner)
|
||||
os.chown(path, owner_info[2], owner_info[3])
|
||||
|
||||
def ChangeMod(path, mode):
|
||||
def chmod(path, mode):
|
||||
os.chmod(path, mode)
|
||||
|
||||
def RemoveFiles(*args, **kwargs):
|
||||
def rm_files(*args):
|
||||
for path in args:
|
||||
if os.path.isfile(path):
|
||||
os.remove(path)
|
||||
|
||||
def CleanupDirs(*args, **kwargs):
|
||||
def rm_dirs(*args):
|
||||
"""
|
||||
Remove all the contents under the directry
|
||||
"""
|
||||
for dirName in args:
|
||||
if os.path.isdir(dirName):
|
||||
for item in os.listdir(dirName):
|
||||
path = os.path.join(dirName, item)
|
||||
for dir_name in args:
|
||||
if os.path.isdir(dir_name):
|
||||
for item in os.listdir(dir_name):
|
||||
path = os.path.join(dir_name, item)
|
||||
if os.path.isfile(path):
|
||||
os.remove(path)
|
||||
elif os.path.isdir(path):
|
||||
shutil.rmtree(path)
|
||||
|
||||
def UpdateConfigFile(path, lineStart, val, chk_err=False):
|
||||
config = []
|
||||
def update_conf_file(path, line_start, val, chk_err=False):
|
||||
conf = []
|
||||
if not os.path.isfile(path) and chk_err:
|
||||
raise Exception("Can't find config file:{0}".format(path))
|
||||
config = GetFileContents(path).split('\n')
|
||||
config = filter(lambda x : not x.startswith(lineStart), config)
|
||||
config.append(val)
|
||||
ReplaceFileContentsAtomic(path, '\n'.join(config))
|
||||
conf = read_file(path).split('\n')
|
||||
conf = filter(lambda x: not x.startswith(line_start), conf)
|
||||
conf.append(val)
|
||||
replace_file(path, '\n'.join(conf))
|
||||
|
||||
def SearchForFile(dirName, fileName):
|
||||
for root, dirs, files in os.walk(dirName):
|
||||
for f in files:
|
||||
if f == fileName:
|
||||
return os.path.join(root, f)
|
||||
def search_file(target_dir_name, target_file_name):
|
||||
for root, dirs, files in os.walk(target_dir_name):
|
||||
for file_name in files:
|
||||
if file_name == target_file_name:
|
||||
return os.path.join(root, file_name)
|
||||
return None
|
||||
|
||||
def ChangeTreeMod(path, mode):
|
||||
def chmod_tree(path, mode):
|
||||
for root, dirs, files in os.walk(path):
|
||||
for f in files:
|
||||
os.chmod(os.path.join(root, f), mode)
|
||||
for file_name in files:
|
||||
os.chmod(os.path.join(root, file_name), mode)
|
||||
|
||||
def FindStringInFile(fname, matchs):
|
||||
def findstr_in_file(file_path, pattern_str):
|
||||
"""
|
||||
Return match object if found in file.
|
||||
"""
|
||||
try:
|
||||
ms=re.compile(matchs)
|
||||
for l in (open(fname,'r')).readlines():
|
||||
m=re.search(ms,l)
|
||||
if m:
|
||||
return m
|
||||
pattern = re.compile(pattern_str)
|
||||
for line in (open(file_path, 'r')).readlines():
|
||||
match = re.search(pattern, line)
|
||||
if match:
|
||||
return match
|
||||
except:
|
||||
raise
|
||||
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@@ -23,5 +23,5 @@ Load OSUtil implementation from azurelinuxagent.distro
|
||||
from azurelinuxagent.distro.default.osutil import OSUtilError
|
||||
import azurelinuxagent.distro.loader as loader
|
||||
|
||||
OSUtil = loader.GetOSUtil()
|
||||
OSUTIL = loader.get_osutil()
|
||||
|
||||
|
||||
@@ -29,54 +29,54 @@ from urlparse import urlparse
|
||||
"""
|
||||
REST api util functions
|
||||
"""
|
||||
__RetryWaitingInterval=10
|
||||
RETRY_WAITING_INTERVAL = 10
|
||||
|
||||
class HttpError(Exception):
|
||||
pass
|
||||
|
||||
def _ParseUrl(url):
|
||||
def _parse_url(url):
|
||||
o = urlparse(url)
|
||||
relativeUrl = o.path
|
||||
rel_uri = o.path
|
||||
if o.fragment:
|
||||
relativeUrl = "{0}#{1}".format(relativeUrl, o.fragment)
|
||||
rel_uri = "{0}#{1}".format(rel_uri, o.fragment)
|
||||
if o.query:
|
||||
relativeUrl = "{0}?{1}".format(relativeUrl, o.query)
|
||||
rel_uri = "{0}?{1}".format(rel_uri, o.query)
|
||||
secure = False
|
||||
if o.scheme.lower() == "https":
|
||||
secure = True
|
||||
return o.hostname, o.port, secure, relativeUrl
|
||||
return o.hostname, o.port, secure, rel_uri
|
||||
|
||||
def GetHttpProxy():
|
||||
def get_http_proxy():
|
||||
"""
|
||||
Get http_proxy and https_proxy from environment variables.
|
||||
Username and password is not supported now.
|
||||
"""
|
||||
host = conf.Get("HttpProxy.Host", None)
|
||||
port = conf.Get("HttpProxy.Port", None)
|
||||
return (host, port)
|
||||
host = conf.get("HttpProxy.Host", None)
|
||||
port = conf.get("HttpProxy.Port", None)
|
||||
return (host, port)
|
||||
|
||||
def _HttpRequest(method, host, relativeUrl, port=None, data=None, secure=False,
|
||||
headers=None, proxyHost=None, proxyPort=None):
|
||||
def _http_request(method, host, rel_uri, port=None, data=None, secure=False,
|
||||
headers=None, proxy_host=None, proxy_port=None):
|
||||
url, conn = None, None
|
||||
if secure:
|
||||
port = 443 if port is None else port
|
||||
if proxyHost is not None and proxyPort is not None:
|
||||
conn = httplib.HTTPSConnection(proxyHost, proxyPort)
|
||||
if proxy_host is not None and proxy_port is not None:
|
||||
conn = httplib.HTTPSConnection(proxy_host, proxy_port)
|
||||
conn.set_tunnel(host, port)
|
||||
#If proxy is used, full url is needed.
|
||||
url = "https://{0}:{1}{2}".format(host, port, relativeUrl)
|
||||
url = "https://{0}:{1}{2}".format(host, port, rel_uri)
|
||||
else:
|
||||
conn = httplib.HTTPSConnection(host, port)
|
||||
url = relativeUrl
|
||||
url = rel_uri
|
||||
else:
|
||||
port = 80 if port is None else port
|
||||
if proxyHost is not None and proxyPort is not None:
|
||||
conn = httplib.HTTPConnection(proxyHost, proxyPort)
|
||||
if proxy_host is not None and proxy_port is not None:
|
||||
conn = httplib.HTTPConnection(proxy_host, proxy_port)
|
||||
#If proxy is used, full url is needed.
|
||||
url = "http://{0}:{1}{2}".format(host, port, relativeUrl)
|
||||
url = "http://{0}:{1}{2}".format(host, port, rel_uri)
|
||||
else:
|
||||
conn = httplib.HTTPConnection(host, port)
|
||||
url = relativeUrl
|
||||
url = rel_uri
|
||||
if headers == None:
|
||||
conn.request(method, url, data)
|
||||
else:
|
||||
@@ -84,65 +84,65 @@ def _HttpRequest(method, host, relativeUrl, port=None, data=None, secure=False,
|
||||
resp = conn.getresponse()
|
||||
return resp
|
||||
|
||||
def HttpRequest(method, url, data, headers=None, maxRetry=3, chkProxy=False):
|
||||
def http_request(method, url, data, headers=None, max_retry=3, chk_proxy=False):
|
||||
"""
|
||||
Sending http request to server
|
||||
On error, sleep 10 and retry maxRetry times.
|
||||
On error, sleep 10 and retry max_retry times.
|
||||
"""
|
||||
logger.Verbose("HTTP Req: {0} {1}", method, url)
|
||||
logger.Verbose(" Data={0}", data)
|
||||
logger.Verbose(" Header={0}", headers)
|
||||
host, port, secure, relativeUrl = _ParseUrl(url)
|
||||
logger.verb("HTTP Req: {0} {1}", method, url)
|
||||
logger.verb(" Data={0}", data)
|
||||
logger.verb(" Header={0}", headers)
|
||||
host, port, secure, rel_uri = _parse_url(url)
|
||||
|
||||
#Check proxy
|
||||
proxyHost, proxyPort = (None, None)
|
||||
if chkProxy:
|
||||
proxyHost, proxyPort = GetHttpProxy()
|
||||
proxy_host, proxy_port = (None, None)
|
||||
if chk_proxy:
|
||||
proxy_host, proxy_port = get_http_proxy()
|
||||
|
||||
#If httplib module is not built with ssl support. Fallback to http
|
||||
if secure and not hasattr(httplib, "HTTPSConnection"):
|
||||
logger.Warn("httplib is not built with ssl support")
|
||||
logger.warn("httplib is not built with ssl support")
|
||||
secure = False
|
||||
|
||||
|
||||
#If httplib module doesn't support https tunnelling. Fallback to http
|
||||
if secure and \
|
||||
proxyHost is not None and \
|
||||
proxyPort is not None and \
|
||||
proxy_host is not None and \
|
||||
proxy_port is not None and \
|
||||
not hasattr(httplib.HTTPSConnection, "set_tunnel"):
|
||||
logger.Warn("httplib doesn't support https tunnelling(new in python 2.7)")
|
||||
logger.warn("httplib doesn't support https tunnelling(new in python 2.7)")
|
||||
secure = False
|
||||
|
||||
for retry in range(0, maxRetry):
|
||||
for retry in range(0, max_retry):
|
||||
try:
|
||||
resp = _HttpRequest(method, host, relativeUrl, port, data,
|
||||
secure, headers, proxyHost, proxyPort)
|
||||
logger.Verbose("HTTP Resp: Status={0}", resp.status)
|
||||
logger.Verbose(" Header={0}", resp.getheaders())
|
||||
resp = _http_request(method, host, rel_uri, port=port, data=data, secure=secure,
|
||||
headers=headers, proxy_host=proxy_host, proxy_port=proxy_port)
|
||||
logger.verb("HTTP Resp: Status={0}", resp.status)
|
||||
logger.verb(" Header={0}", resp.getheaders())
|
||||
return resp
|
||||
except httplib.HTTPException as e:
|
||||
logger.Warn('HTTPException {0}, args:{1}', e, repr(e.args))
|
||||
logger.warn('HTTPException {0}, args:{1}', e, repr(e.args))
|
||||
except IOError as e:
|
||||
logger.Warn('Socket IOError {0}, args:{1}', e, repr(e.args))
|
||||
logger.warn('Socket IOError {0}, args:{1}', e, repr(e.args))
|
||||
|
||||
if retry < maxRetry - 1:
|
||||
logger.Info("Retry={0}, {1} {2}", retry, method, url)
|
||||
time.sleep(__RetryWaitingInterval)
|
||||
if retry < max_retry - 1:
|
||||
logger.info("Retry={0}, {1} {2}", retry, method, url)
|
||||
time.sleep(RETRY_WAITING_INTERVAL)
|
||||
|
||||
raise HttpError("HTTP Err: {0} {1}".format(method, url))
|
||||
|
||||
def HttpGet(url, headers=None, maxRetry=3, chkProxy=False):
|
||||
return HttpRequest("GET", url, None, headers, maxRetry, chkProxy)
|
||||
|
||||
def HttpHead(url, headers=None, maxRetry=3, chkProxy=False):
|
||||
return HttpRequest("HEAD", url, None, headers, maxRetry, chkProxy)
|
||||
|
||||
def HttpPost(url, data, headers=None, maxRetry=3, chkProxy=False):
|
||||
return HttpRequest("POST", url, data, headers, maxRetry, chkProxy)
|
||||
def http_get(url, headers=None, max_retry=3, chk_proxy=False):
|
||||
return http_request("GET", url, data=None, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
|
||||
|
||||
def HttpPut(url, data, headers=None, maxRetry=3, chkProxy=False):
|
||||
return HttpRequest("PUT", url, data, headers, maxRetry, chkProxy)
|
||||
def http_head(url, headers=None, max_retry=3, chk_proxy=False):
|
||||
return http_request("HEAD", url, None, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
|
||||
|
||||
def HttpDelete(url, headers=None, maxRetry=3, chkProxy=False):
|
||||
return HttpRequest("DELETE", url, None, headers, maxRetry, chkProxy)
|
||||
def http_post(url, data, headers=None, max_retry=3, chk_proxy=False):
|
||||
return http_request("POST", url, data, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
|
||||
|
||||
def http_put(url, data, headers=None, max_retry=3, chk_proxy=False):
|
||||
return http_request("PUT", url, data, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
|
||||
|
||||
def http_delete(url, headers=None, max_retry=3, chk_proxy=False):
|
||||
return http_request("DELETE", url, None, headers=headers, max_retry=max_retry, chk_proxy=chk_proxy)
|
||||
|
||||
#End REST api util functions
|
||||
|
||||
@@ -50,59 +50,35 @@ if not hasattr(subprocess,'check_output'):
|
||||
|
||||
subprocess.check_output=check_output
|
||||
subprocess.CalledProcessError=CalledProcessError
|
||||
|
||||
|
||||
|
||||
"""
|
||||
Shell command util functions
|
||||
"""
|
||||
def Run(cmd, chk_err=True):
|
||||
def run(cmd, chk_err=True):
|
||||
"""
|
||||
Calls RunGetOutput on 'cmd', returning only the return code.
|
||||
Calls run_get_output on 'cmd', returning only the return code.
|
||||
If chk_err=True then errors will be reported in the log.
|
||||
If chk_err=False then errors will be suppressed from the log.
|
||||
"""
|
||||
retcode,out=RunGetOutput(cmd,chk_err)
|
||||
retcode,out=run_get_output(cmd,chk_err)
|
||||
return retcode
|
||||
|
||||
def RunGetOutput(cmd, chk_err=True):
|
||||
def run_get_output(cmd, chk_err=True):
|
||||
"""
|
||||
Wrapper for subprocess.check_output.
|
||||
Execute 'cmd'. Returns return code and STDOUT, trapping expected exceptions.
|
||||
Reports exceptions to Error if chk_err parameter is True
|
||||
"""
|
||||
logger.Verbose("Run cmd '{0}'", cmd)
|
||||
try:
|
||||
logger.verb("run cmd '{0}'", cmd)
|
||||
try:
|
||||
output=subprocess.check_output(cmd,stderr=subprocess.STDOUT,shell=True)
|
||||
except subprocess.CalledProcessError,e :
|
||||
if chk_err :
|
||||
logger.Error("Run cmd '{0}' failed", e.cmd)
|
||||
logger.Error("Error Code:{0}", e.returncode)
|
||||
logger.Error("Result:{0}", e.output[:-1].decode('latin-1'))
|
||||
logger.error("run cmd '{0}' failed", e.cmd)
|
||||
logger.error("Error Code:{0}", e.returncode)
|
||||
logger.error("Result:{0}", e.output[:-1].decode('latin-1'))
|
||||
return e.returncode, e.output.decode('latin-1')
|
||||
return 0, output
|
||||
|
||||
def RunSendStdin(cmd, input, chk_err=True):
|
||||
"""
|
||||
Wrapper for subprocess.Popen.
|
||||
Execute 'cmd', sending 'input' to STDIN of 'cmd'.
|
||||
Returns return code and STDOUT, trapping expected exceptions.
|
||||
Reports exceptions to Error if chk_err parameter is True
|
||||
"""
|
||||
logger.Verbose("Run cmd '{0}'", cmd)
|
||||
try:
|
||||
me=subprocess.Popen([cmd], shell=True, stdin=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,stdout=subprocess.PIPE)
|
||||
output=me.communicate(input)
|
||||
except OSError , e :
|
||||
if chk_err :
|
||||
logger.Error("Run cmd '{0}' failed", e.cmd)
|
||||
logger.Error("Error Code:{0}", e.returncode)
|
||||
logger.Error("Result:{0}", e.output[:-1].decode('latin-1'))
|
||||
return e.returncode, e.output.decode('latin-1')
|
||||
if me.returncode is not 0 and chk_err is True:
|
||||
logger.Error("Run cmd '{0}' failed", cmd)
|
||||
logger.Error("Error Code:{0}", me.returncode)
|
||||
logger.Error("Result:{0}", output[0].decode('latin-1'))
|
||||
return me.returncode, output[0].decode('latin-1')
|
||||
|
||||
#End shell command util functions
|
||||
|
||||
@@ -15,21 +15,22 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
|
||||
import crypt
|
||||
import random
|
||||
import string
|
||||
import struct
|
||||
|
||||
def FindFirstNode(xmlDoc, selector):
|
||||
nodes = FindAllNodes(xmlDoc, selector)
|
||||
def find_first_node(xml_doc, selector):
|
||||
nodes = find_all_nodes(xml_doc, selector)
|
||||
if len(nodes) > 0:
|
||||
return nodes[0]
|
||||
|
||||
def FindAllNodes(xmlDoc, selector):
|
||||
nodes = xmlDoc.findall(selector)
|
||||
def find_all_nodes(xml_doc, selector):
|
||||
nodes = xml_doc.findall(selector)
|
||||
return nodes
|
||||
|
||||
def GetNodeTextData(a):
|
||||
def get_node_text(a):
|
||||
"""
|
||||
Filter non-text nodes from DOM tree
|
||||
"""
|
||||
@@ -37,54 +38,54 @@ def GetNodeTextData(a):
|
||||
if b.nodeType == b.TEXT_NODE:
|
||||
return b.data
|
||||
|
||||
def Unpack(buf, offset, range):
|
||||
def unpack(buf, offset, range):
|
||||
"""
|
||||
Unpack bytes into python values.
|
||||
"""
|
||||
result = 0
|
||||
for i in range:
|
||||
result = (result << 8) | Ord(buf[offset + i])
|
||||
result = (result << 8) | str_to_ord(buf[offset + i])
|
||||
return result
|
||||
|
||||
def UnpackLittleEndian(buf, offset, length):
|
||||
def unpack_little_endian(buf, offset, length):
|
||||
"""
|
||||
Unpack little endian bytes into python values.
|
||||
"""
|
||||
return Unpack(buf, offset, list(range(length - 1, -1, -1)))
|
||||
return unpack(buf, offset, list(range(length - 1, -1, -1)))
|
||||
|
||||
def UnpackBigEndian(buf, offset, length):
|
||||
def unpack_big_endian(buf, offset, length):
|
||||
"""
|
||||
Unpack big endian bytes into python values.
|
||||
"""
|
||||
return Unpack(buf, offset, list(range(0, length)))
|
||||
return unpack(buf, offset, list(range(0, length)))
|
||||
|
||||
def HexDump3(buf, offset, length):
|
||||
def hex_dump3(buf, offset, length):
|
||||
"""
|
||||
Dump range of buf in formatted hex.
|
||||
"""
|
||||
return ''.join(['%02X' % Ord(char) for char in buf[offset:offset + length]])
|
||||
return ''.join(['%02X' % str_to_ord(char) for char in buf[offset:offset + length]])
|
||||
|
||||
def HexDump2(buf):
|
||||
def hex_dump2(buf):
|
||||
"""
|
||||
Dump buf in formatted hex.
|
||||
"""
|
||||
return HexDump3(buf, 0, len(buf))
|
||||
return hex_dump3(buf, 0, len(buf))
|
||||
|
||||
def IsInRangeInclusive(a, low, high):
|
||||
def is_in_range(a, low, high):
|
||||
"""
|
||||
Return True if 'a' in 'low' <= a >= 'high'
|
||||
"""
|
||||
return (a >= low and a <= high)
|
||||
|
||||
def IsPrintable(ch):
|
||||
def is_printable(ch):
|
||||
"""
|
||||
Return True if character is displayable.
|
||||
"""
|
||||
return (IsInRangeInclusive(ch, Ord('A'), Ord('Z'))
|
||||
or IsInRangeInclusive(ch, Ord('a'), Ord('z'))
|
||||
or IsInRangeInclusive(ch, Ord('0'), Ord('9')))
|
||||
return (is_in_range(ch, str_to_ord('A'), str_to_ord('Z'))
|
||||
or is_in_range(ch, str_to_ord('a'), str_to_ord('z'))
|
||||
or is_in_range(ch, str_to_ord('0'), str_to_ord('9')))
|
||||
|
||||
def HexDump(buffer, size):
|
||||
def hex_dump(buffer, size):
|
||||
"""
|
||||
Return Hex formated dump of a 'buffer' of 'size'.
|
||||
"""
|
||||
@@ -111,16 +112,16 @@ def HexDump(buffer, size):
|
||||
for j in range(i - (i % 16), i + 1):
|
||||
byte=buffer[j]
|
||||
if type(byte) == str:
|
||||
byte = ord(byte.decode('latin1'))
|
||||
byte = str_to_ord(byte.decode('latin1'))
|
||||
k = '.'
|
||||
if IsPrintable(byte):
|
||||
if is_printable(byte):
|
||||
k = chr(byte)
|
||||
result += k
|
||||
if (i + 1) != size:
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
def Ord(a):
|
||||
def str_to_ord(a):
|
||||
"""
|
||||
Allows indexing into a string or an array of integers transparently.
|
||||
Generic utility function.
|
||||
@@ -129,25 +130,25 @@ def Ord(a):
|
||||
a = ord(a)
|
||||
return a
|
||||
|
||||
def CompareBytes(a, b, start, length):
|
||||
def compare_bytes(a, b, start, length):
|
||||
for offset in range(start, start + length):
|
||||
if Ord(a[offset]) != Ord(b[offset]):
|
||||
if str_to_ord(a[offset]) != str_to_ord(b[offset]):
|
||||
return False
|
||||
return True
|
||||
|
||||
def IntegerToIpAddressV4String(a):
|
||||
def int_to_ip4_addr(a):
|
||||
"""
|
||||
Build DHCP request string.
|
||||
"""
|
||||
return "%u.%u.%u.%u" % ((a >> 24) & 0xFF,
|
||||
(a >> 16) & 0xFF,
|
||||
(a >> 8) & 0xFF,
|
||||
return "%u.%u.%u.%u" % ((a >> 24) & 0xFF,
|
||||
(a >> 16) & 0xFF,
|
||||
(a >> 8) & 0xFF,
|
||||
(a) & 0xFF)
|
||||
|
||||
def Ascii(val):
|
||||
def ascii(val):
|
||||
uni = None
|
||||
if type(val) == str:
|
||||
uni = unicode(val, 'utf-8', errors='ignore')
|
||||
uni = unicode(val, 'utf-8', errors='ignore')
|
||||
else:
|
||||
uni = unicode(val)
|
||||
if uni is None:
|
||||
@@ -155,7 +156,7 @@ def Ascii(val):
|
||||
else:
|
||||
return uni.encode('ascii', 'backslashreplace')
|
||||
|
||||
def HexStringToByteArray(a):
|
||||
def hexstr_to_bytearray(a):
|
||||
"""
|
||||
Return hex string packed into a binary struct.
|
||||
"""
|
||||
@@ -164,7 +165,7 @@ def HexStringToByteArray(a):
|
||||
b += struct.pack("B", int(a[c * 2:c * 2 + 2], 16))
|
||||
return b
|
||||
|
||||
def SetSshConfig(config, name, val):
|
||||
def set_ssh_config(config, name, val):
|
||||
notfound = True
|
||||
for i in range(0, len(config)):
|
||||
if config[i].startswith(name):
|
||||
@@ -177,20 +178,20 @@ def SetSshConfig(config, name, val):
|
||||
config.insert(i, "{0} {1}".format(name, val))
|
||||
return config
|
||||
|
||||
def RemoveBom(c):
|
||||
if ord(c[0]) > 128 and ord(c[1]) > 128 and ord(c[2]) > 128:
|
||||
def remove_bom(c):
|
||||
if str_to_ord(c[0]) > 128 and str_to_ord(c[1]) > 128 and str_to_ord(c[2]) > 128:
|
||||
c = c[3:]
|
||||
return c
|
||||
|
||||
def GetPasswordHash(password, useSalt, saltType, saltLength):
|
||||
salt="$6$"
|
||||
if useSalt:
|
||||
def gen_password_hash(password, use_salt, salt_type, salt_len):
|
||||
salt="$6$"
|
||||
if use_salt:
|
||||
collection = string.ascii_letters + string.digits
|
||||
salt = ''.join(random.choice(collection) for _ in range(saltLength))
|
||||
salt = "${0}${1}".format(saltType, salt)
|
||||
salt = ''.join(random.choice(collection) for _ in range(salt_len))
|
||||
salt = "${0}${1}".format(salt_type, salt)
|
||||
return crypt.crypt(password, salt)
|
||||
|
||||
def NumberToBytes(i):
|
||||
def num_to_bytes(i):
|
||||
"""
|
||||
Pack number into bytes. Retun as string.
|
||||
"""
|
||||
@@ -201,7 +202,7 @@ def NumberToBytes(i):
|
||||
result.reverse()
|
||||
return ''.join(result)
|
||||
|
||||
def BitsToString(a):
|
||||
def bits_to_str(a):
|
||||
"""
|
||||
Return string representation of bits in a.
|
||||
"""
|
||||
|
||||
+627
-214
File diff suppressed because it is too large
Load Diff
@@ -18,11 +18,11 @@
|
||||
#
|
||||
|
||||
import os
|
||||
from azurelinuxagent.metadata import GuestAgentName, GuestAgentVersion, \
|
||||
GuestAgentDescription, \
|
||||
DistroName, DistroVersion, DistroFullName
|
||||
from azurelinuxagent.metadata import AGENT_NAME, AGENT_VERSION, \
|
||||
agent_description, \
|
||||
DISTRO_NAME, DISTRO_VERSION, DISTRO_FULL_NAME
|
||||
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
import azurelinuxagent.agent as agent
|
||||
import setuptools
|
||||
from setuptools import find_packages
|
||||
@@ -103,9 +103,9 @@ class install(_install):
|
||||
|
||||
def initialize_options(self):
|
||||
_install.initialize_options(self)
|
||||
self.lnx_distro = DistroName
|
||||
self.lnx_distro_version = DistroVersion
|
||||
self.lnx_distro_fullname = DistroFullName
|
||||
self.lnx_distro = DISTRO_NAME
|
||||
self.lnx_distro_version = DISTRO_VERSION
|
||||
self.lnx_distro_fullname = DISTRO_FULL_NAME
|
||||
self.register_service = False
|
||||
|
||||
def finalize_options(self):
|
||||
@@ -118,11 +118,11 @@ class install(_install):
|
||||
def run(self):
|
||||
_install.run(self)
|
||||
if self.register_service:
|
||||
agent.RegisterService()
|
||||
agent.register_service()
|
||||
|
||||
setuptools.setup(name=GuestAgentName,
|
||||
version=GuestAgentVersion,
|
||||
long_description=GuestAgentDescription,
|
||||
setuptools.setup(name=AGENT_NAME,
|
||||
version=AGENT_VERSION,
|
||||
long_description=agent_description,
|
||||
author= 'Yue Zhang, Stephen Zarkos, Eric Gable',
|
||||
author_email = 'walinuxagent@microsoft.com',
|
||||
platforms = 'Linux',
|
||||
|
||||
@@ -1,75 +0,0 @@
|
||||
# Copyright 2014 Microsoft Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
# Implements parts of RFC 2131, 1541, 1497 and
|
||||
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
|
||||
# http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx
|
||||
|
||||
import env
|
||||
import tests.tools as tools
|
||||
from tools import *
|
||||
import uuid
|
||||
import shutil
|
||||
import unittest
|
||||
import os
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.utils.textutil as textutil
|
||||
from azurelinuxagent.utils.osutil import CurrOSInfo, CurrOS
|
||||
import test
|
||||
|
||||
"""
|
||||
OS related test. Need to run with root privilege.
|
||||
|
||||
CAUSION: during the test, user account and system config may be changed
|
||||
"""
|
||||
class TestUserOperation(unittest.TestCase):
|
||||
def test_sysuser(self):
|
||||
self.assertTrue(CurrOS.IsSysUser('root'))
|
||||
|
||||
def test_update_user_account(self):
|
||||
userName="nobodywillusethisname"
|
||||
shellutil.Run('userdel -f -r ' + userName)
|
||||
self.assertFalse(tools.simple_file_grep('/etc/passwd', userName))
|
||||
self.assertFalse(os.path.isdir(os.path.join(CurrOS.GetHome(), userName)))
|
||||
|
||||
CurrOS.UpdateUserAccount(userName, "User@123")
|
||||
self.assertTrue(tools.simple_file_grep('/etc/passwd', userName))
|
||||
self.assertTrue(tools.simple_file_grep('/etc/sudoers.d/waagent',
|
||||
userName))
|
||||
self.assertTrue(os.path.isdir(os.path.join(CurrOS.GetHome(), userName)))
|
||||
|
||||
class TestSshOperation(unittest.TestCase):
|
||||
def _setUp(self):
|
||||
logger.AddLoggerAppender(logger.AppenderConfig({
|
||||
"type":"CONSOLE",
|
||||
"level":"VERBOSE",
|
||||
"console_path":"/dev/stdout"
|
||||
}))
|
||||
|
||||
def test_regen_ssh_host_key(self):
|
||||
oldKey = fileutil.GetFileContents('/etc/ssh/ssh_host_rsa_key')
|
||||
CurrOS.RegenerateSshHostkey('rsa')
|
||||
newKey = fileutil.GetFileContents('/etc/ssh/ssh_host_rsa_key')
|
||||
self.assertNotEquals(oldKey, newKey)
|
||||
|
||||
#TODO test dvd mount
|
||||
#TODO test set scsi
|
||||
#TODO test set/publish hostname
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
@@ -1,38 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# Windows Azure Linux Agent
|
||||
#
|
||||
# Copyright 2014 Microsoft Corporation
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# Requires Python 2.4+ and Openssl 1.0+
|
||||
#
|
||||
# Implements parts of RFC 2131, 1541, 1497 and
|
||||
# http://msdn.microsoft.com/en-us/library/cc227282%28PROT.10%29.aspx
|
||||
# http://msdn.microsoft.com/en-us/library/cc227259%28PROT.13%29.aspx
|
||||
#
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
test_root = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.dirname(test_root)
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
import azurelinuxagent.agent as agent
|
||||
|
||||
if __name__ == '__main__':
|
||||
agent.Main()
|
||||
|
||||
|
||||
+2
-2
@@ -29,14 +29,14 @@ import azurelinuxagent.agent as agent
|
||||
|
||||
class TestAgent(unittest.TestCase):
|
||||
def test_parse_args(self):
|
||||
cmd, force, verbose = agent.ParseArgs(["deprovision+user",
|
||||
cmd, force, verbose = agent.parse_args(["deprovision+user",
|
||||
"-force",
|
||||
"/verbose"])
|
||||
self.assertEquals("deprovision+user", cmd)
|
||||
self.assertTrue(force)
|
||||
self.assertTrue(verbose)
|
||||
|
||||
cmd, force, verbose = agent.ParseArgs(["wrong cmd"])
|
||||
cmd, force, verbose = agent.parse_args(["wrong cmd"])
|
||||
self.assertEquals("help", cmd)
|
||||
self.assertFalse(force)
|
||||
self.assertFalse(verbose)
|
||||
|
||||
+10
-11
@@ -27,7 +27,7 @@ import json
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.protocol.v1 as v1
|
||||
|
||||
CertificatesSample="""\
|
||||
certs_sample="""\
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<CertificateFile xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="certificates10.xsd">
|
||||
<Version>2012-11-30</Version>
|
||||
@@ -111,7 +111,7 @@ h+249Wj0Bw==
|
||||
</CertificateFile>
|
||||
"""
|
||||
|
||||
TransportCert="""\
|
||||
transport_cert="""\
|
||||
-----BEGIN CERTIFICATE-----
|
||||
MIIDBzCCAe+gAwIBAgIJANujJuVt5eC8MA0GCSqGSIb3DQEBCwUAMBkxFzAVBgNV
|
||||
BAMMDkxpbnV4VHJhbnNwb3J0MCAXDTE0MTAyNDA3MjgwN1oYDzIxMDQwNzEyMDcy
|
||||
@@ -133,7 +133,7 @@ DsfY6XGSEIhZnA4=
|
||||
-----END CERTIFICATE-----
|
||||
"""
|
||||
|
||||
TransportPrivate="""\
|
||||
transport_private="""\
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDT3CQJHeleTXqI
|
||||
EioyHk1zlguccyk8riT7eGGTFUqLFNJSOmtMB6foNnz9ts61bMFvkCc+gFVqIyBK
|
||||
@@ -180,19 +180,18 @@ class TestCertificates(unittest.TestCase):
|
||||
os.remove(crt2)
|
||||
if os.path.isfile(prv2):
|
||||
os.remove(prv2)
|
||||
fileutil.SetFileContents(os.path.join('/tmp', v1.TransportCertFile),
|
||||
TransportCert)
|
||||
fileutil.SetFileContents(os.path.join('/tmp', v1.TransportPrivateFile),
|
||||
TransportPrivate)
|
||||
config = v1.Certificates(CertificatesSample)
|
||||
fileutil.write_file(os.path.join('/tmp', "TransportCert.pem"),
|
||||
transport_cert)
|
||||
fileutil.write_file(os.path.join('/tmp', "TransportPrivate.pem"),
|
||||
transport_private)
|
||||
config = v1.Certificates(certs_sample)
|
||||
self.assertNotEquals(None, config)
|
||||
self.assertTrue(os.path.isfile(crt1))
|
||||
self.assertTrue(os.path.isfile(crt2))
|
||||
self.assertTrue(os.path.isfile(prv2))
|
||||
|
||||
certs = config.getCerts()
|
||||
self.assertNotEquals(0, len(certs.certificates))
|
||||
cert = certs.certificates[0]
|
||||
self.assertNotEquals(0, len(config.cert_list.certificates))
|
||||
cert = config.cert_list.certificates[0]
|
||||
self.assertNotEquals(None, cert.thumbprint)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
+9
-9
@@ -42,15 +42,15 @@ class TestConfiguration(unittest.TestCase):
|
||||
def test_parse_conf(self):
|
||||
config = conf.ConfigurationProvider()
|
||||
config.load(TestConf)
|
||||
self.assertEquals(True, config.getSwitch("foo.bar.switch"))
|
||||
self.assertEquals(False, config.getSwitch("foo.bar.switch2"))
|
||||
self.assertEquals(False, config.getSwitch("foo.bar.switch3"))
|
||||
self.assertEquals(True, config.getSwitch("foo.bar.switch4", True))
|
||||
self.assertEquals(True, config.get_switch("foo.bar.switch"))
|
||||
self.assertEquals(False, config.get_switch("foo.bar.switch2"))
|
||||
self.assertEquals(False, config.get_switch("foo.bar.switch3"))
|
||||
self.assertEquals(True, config.get_switch("foo.bar.switch4", True))
|
||||
self.assertEquals("foobar", config.get("foo.bar.str"))
|
||||
self.assertEquals("foobar1", config.get("foo.bar.str1", "foobar1"))
|
||||
self.assertEquals(300, config.getInt("foo.bar.int"))
|
||||
self.assertEquals(-1, config.getInt("foo.bar.int2"))
|
||||
self.assertEquals(-1, config.getInt("foo.bar.str"))
|
||||
self.assertEquals(300, config.get_int("foo.bar.int"))
|
||||
self.assertEquals(-1, config.get_int("foo.bar.int2"))
|
||||
self.assertEquals(-1, config.get_int("foo.bar.str"))
|
||||
|
||||
def test_parse_malformed_conf(self):
|
||||
with self.assertRaises(AgentConfigError) as cm:
|
||||
@@ -63,8 +63,8 @@ class TestConfiguration(unittest.TestCase):
|
||||
F.close()
|
||||
|
||||
config = conf.ConfigurationProvider()
|
||||
conf.LoadConfiguration('/tmp/test_conf', conf=config)
|
||||
self.assertEquals(True, config.getSwitch("foo.bar.switch"), False)
|
||||
conf.load_conf('/tmp/test_conf', conf=config)
|
||||
self.assertEquals(True, config.get_switch("foo.bar.switch"), False)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -21,7 +21,7 @@
|
||||
import env
|
||||
from tests.tools import *
|
||||
import unittest
|
||||
import azurelinuxagent.distro.default.deprovision as deprovisionHandler
|
||||
import azurelinuxagent.distro.default.deprovision as deprovision_handler
|
||||
|
||||
def MockAction(param):
|
||||
#print param
|
||||
@@ -30,24 +30,24 @@ def MockAction(param):
|
||||
def MockSetup(self, deluser):
|
||||
warnings = ["Print warning to console"]
|
||||
actions = [
|
||||
deprovisionHandler.DeprovisionAction(MockAction, ['Take action'])
|
||||
deprovision_handler.DeprovisionAction(MockAction, ['Take action'])
|
||||
]
|
||||
return warnings, actions
|
||||
|
||||
class TestDeprovisionHandler(unittest.TestCase):
|
||||
def test_setUp(self):
|
||||
handler = deprovisionHandler.DeprovisionHandler()
|
||||
warnings, actions = handler.setUp(False)
|
||||
def test_setup(self):
|
||||
handler = deprovision_handler.DeprovisionHandler()
|
||||
warnings, actions = handler.setup(False)
|
||||
self.assertNotEquals(None, warnings)
|
||||
self.assertNotEquals(0, len(warnings))
|
||||
self.assertNotEquals(None, actions)
|
||||
self.assertNotEquals(0, len(actions))
|
||||
self.assertEquals(deprovisionHandler.DeprovisionAction, type(actions[0]))
|
||||
self.assertEquals(deprovision_handler.DeprovisionAction, type(actions[0]))
|
||||
|
||||
|
||||
@Mockup(deprovisionHandler.DeprovisionHandler, 'setUp', MockSetup)
|
||||
@mock(deprovision_handler.DeprovisionHandler, 'setup', MockSetup)
|
||||
def test_deprovision(self):
|
||||
handler = deprovisionHandler.DeprovisionHandler()
|
||||
handler = deprovision_handler.DeprovisionHandler()
|
||||
handler.deprovision(force=True)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
+13
-13
@@ -25,34 +25,34 @@ import unittest
|
||||
import os
|
||||
import json
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.distro.default.dhcp as dhcpHandler
|
||||
import azurelinuxagent.distro.default.dhcp as dhcp_handler
|
||||
|
||||
SampleDhcpResponse = None
|
||||
with open(os.path.join(env.test_root, "dhcp")) as F:
|
||||
SampleDhcpResponse = F.read()
|
||||
|
||||
MockSocketSend = MockFunc('SocketSend', SampleDhcpResponse)
|
||||
MockGenTransactionId = MockFunc('GenTransactionId', "\xC6\xAA\xD1\x5D")
|
||||
MockGetMacAddress = MockFunc('GetMacAddress', "\x00\x15\x5D\x38\xAA\x38")
|
||||
mock_socket_send = MockFunc('socket_send', SampleDhcpResponse)
|
||||
mock_gen_trans_id = MockFunc('gen_trans_id', "\xC6\xAA\xD1\x5D")
|
||||
mock_get_mac_addr = MockFunc('get_mac_addr', "\x00\x15\x5D\x38\xAA\x38")
|
||||
|
||||
class TestdhcpHandler(unittest.TestCase):
|
||||
|
||||
def test_build_dhcp_req(self):
|
||||
req = dhcpHandler.BuildDhcpRequest(MockGetMacAddress())
|
||||
req = dhcp_handler.build_dhcp_request(mock_get_mac_addr())
|
||||
self.assertNotEquals(None, req)
|
||||
|
||||
@Mockup(dhcpHandler, "GenTransactionId", MockGenTransactionId)
|
||||
@Mockup(dhcpHandler, "SocketSend", MockSocketSend)
|
||||
@mock(dhcp_handler, "gen_trans_id", mock_gen_trans_id)
|
||||
@mock(dhcp_handler, "socket_send", mock_socket_send)
|
||||
def test_send_dhcp_req(self):
|
||||
req = dhcpHandler.BuildDhcpRequest(MockGetMacAddress())
|
||||
resp = dhcpHandler.SendDhcpRequest(req)
|
||||
req = dhcp_handler.build_dhcp_request(mock_get_mac_addr())
|
||||
resp = dhcp_handler.send_dhcp_request(req)
|
||||
self.assertNotEquals(None, resp)
|
||||
|
||||
@Mockup(dhcpHandler, "SocketSend", MockSocketSend)
|
||||
@Mockup(dhcpHandler, "GenTransactionId", MockGenTransactionId)
|
||||
@Mockup(dhcpHandler.OSUtil, "GetMacAddress", MockGetMacAddress)
|
||||
@mock(dhcp_handler, "socket_send", mock_socket_send)
|
||||
@mock(dhcp_handler, "gen_trans_id", mock_gen_trans_id)
|
||||
@mock(dhcp_handler.OSUtil, "get_mac_addr", mock_get_mac_addr)
|
||||
def test_handle_dhcp(self):
|
||||
dh = dhcpHandler.DhcpHandler()
|
||||
dh = dhcp_handler.DhcpHandler()
|
||||
dh.probe()
|
||||
self.assertEquals("10.62.144.1", dh.gateway)
|
||||
self.assertEquals("10.62.144.140", dh.endpoint)
|
||||
|
||||
+12
-12
@@ -21,22 +21,22 @@
|
||||
import env
|
||||
from tests.tools import *
|
||||
import unittest
|
||||
from azurelinuxagent.utils.osutil import OSUtil, OSUtilError
|
||||
from azurelinuxagent.handler import Handlers
|
||||
from azurelinuxagent.utils.osutil import OSUTIL, OSUtilError
|
||||
from azurelinuxagent.handler import HANDLERS
|
||||
import azurelinuxagent.distro.default.osutil as osutil
|
||||
|
||||
class TestDistroLoader(unittest.TestCase):
|
||||
def test_loader(self):
|
||||
self.assertNotEquals(osutil.DefaultOSUtil, type(OSUtil))
|
||||
self.assertNotEquals(None, Handlers.initHandler)
|
||||
self.assertNotEquals(None, Handlers.runHandler)
|
||||
self.assertNotEquals(None, Handlers.scvmmHandler)
|
||||
self.assertNotEquals(None, Handlers.dhcpHandler)
|
||||
self.assertNotEquals(None, Handlers.envHandler)
|
||||
self.assertNotEquals(None, Handlers.provisionHandler)
|
||||
self.assertNotEquals(None, Handlers.resourceDiskHandler)
|
||||
self.assertNotEquals(None, Handlers.envHandler)
|
||||
self.assertNotEquals(None, Handlers.deprovisionHandler)
|
||||
self.assertNotEquals(osutil.DefaultOSUtil, type(OSUTIL))
|
||||
self.assertNotEquals(None, HANDLERS.init_handler)
|
||||
self.assertNotEquals(None, HANDLERS.main_handler)
|
||||
self.assertNotEquals(None, HANDLERS.scvmm_handler)
|
||||
self.assertNotEquals(None, HANDLERS.dhcp_handler)
|
||||
self.assertNotEquals(None, HANDLERS.env_handler)
|
||||
self.assertNotEquals(None, HANDLERS.provision_handler)
|
||||
self.assertNotEquals(None, HANDLERS.resource_disk_handler)
|
||||
self.assertNotEquals(None, HANDLERS.env_handler)
|
||||
self.assertNotEquals(None, HANDLERS.deprovision_handler)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
+12
-10
@@ -22,28 +22,30 @@ import env
|
||||
from tests.tools import *
|
||||
import unittest
|
||||
import time
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
from azurelinuxagent.distro.default.env import EnvMonitor
|
||||
|
||||
class MockDhcpHandler(object):
|
||||
def configRoutes(self):
|
||||
def conf_routes(self):
|
||||
pass
|
||||
|
||||
MockDhcpProcessIdNotChange = MockFunc(retval="1234")
|
||||
def MockDhcpProcessIdChange():
|
||||
def mock_get_dhcp_pid():
|
||||
return "1234"
|
||||
|
||||
def mock_dhcp_pid_change():
|
||||
return str(time.time())
|
||||
|
||||
class TestEnvMonitor(unittest.TestCase):
|
||||
|
||||
@Mockup(OSUtil, 'GetDhcpProcessId', MockDhcpProcessIdNotChange)
|
||||
def test_dhcpProcessIdNotChange(self):
|
||||
@mock(OSUTIL, 'get_dhcp_pid', mock_get_dhcp_pid)
|
||||
def test_dhcp_pid_not_change(self):
|
||||
monitor = EnvMonitor(MockDhcpHandler())
|
||||
monitor.handleDhcpClientRestart()
|
||||
monitor.handle_dhclient_restart()
|
||||
|
||||
@Mockup(OSUtil, 'GetDhcpProcessId', MockDhcpProcessIdChange)
|
||||
def test_dhcpProcessIdChange(self):
|
||||
@mock(OSUTIL, 'get_dhcp_pid', mock_dhcp_pid_change)
|
||||
def test_dhcp_pid_change(self):
|
||||
monitor = EnvMonitor(MockDhcpHandler())
|
||||
monitor.handleDhcpClientRestart()
|
||||
monitor.handle_dhclient_restart()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
+7
-7
@@ -29,25 +29,25 @@ import azurelinuxagent.event as evt
|
||||
import azurelinuxagent.protocol as prot
|
||||
|
||||
class MockProtocol(object):
|
||||
def getInstanceMetadata(self):
|
||||
def get_instance_metadata(self):
|
||||
return prot.InstanceMetadata(deploymentName='foo', roleName='foo',
|
||||
roleInstanceId='foo', containerId='foo')
|
||||
def reportEvent(self, data): pass
|
||||
def report_event(self, data): pass
|
||||
|
||||
class TestEvent(unittest.TestCase):
|
||||
def test_save(self):
|
||||
if not os.path.exists("/tmp/events"):
|
||||
os.mkdir("/tmp/events")
|
||||
evt.AddExtensionEvent("Test", "Test", True)
|
||||
evt.add_event("Test", "Test", True)
|
||||
eventsFile = os.listdir("/tmp/events")
|
||||
self.assertNotEquals(0, len(eventsFile))
|
||||
shutil.rmtree("/tmp/events")
|
||||
|
||||
@Mockup(evt.prot.Factory, 'getDefaultProtocol',
|
||||
MockFunc(retval=MockProtocol()))
|
||||
def test_initSystemInfo(self):
|
||||
@mock(evt.prot.Factory, 'get_default_protocol',
|
||||
MockFunc(retval=MockProtocol()))
|
||||
def test_init_sys_info(self):
|
||||
monitor = evt.EventMonitor()
|
||||
self.assertNotEquals(0, len(monitor.sysInfo))
|
||||
self.assertNotEquals(0, len(monitor.sysinfo))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
+72
-72
@@ -25,12 +25,12 @@ import unittest
|
||||
import os
|
||||
import json
|
||||
import azurelinuxagent.logger as logger
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.protocol as prot
|
||||
import azurelinuxagent.distro.default.extension as ext
|
||||
|
||||
extensionData = {
|
||||
ext_sample_json = {
|
||||
"name":"TestExt",
|
||||
"properties":{
|
||||
"version":"2.0",
|
||||
@@ -51,10 +51,10 @@ extensionData = {
|
||||
}]
|
||||
}
|
||||
}
|
||||
extension = prot.Extension()
|
||||
prot.set_properties(extension, extensionData)
|
||||
ext_sample = prot.Extension()
|
||||
prot.set_properties(ext_sample, ext_sample_json)
|
||||
|
||||
packageListData={
|
||||
pkd_list_sample_str={
|
||||
"versions": [{
|
||||
"version": "2.0",
|
||||
"uris":[{
|
||||
@@ -67,10 +67,10 @@ packageListData={
|
||||
}]
|
||||
}]
|
||||
}
|
||||
packageList = prot.ExtensionPackageList()
|
||||
prot.set_properties(packageList, packageListData)
|
||||
pkg_list_sample = prot.ExtensionPackageList()
|
||||
prot.set_properties(pkg_list_sample, pkd_list_sample_str)
|
||||
|
||||
manJson = {
|
||||
manifest_sample_str = {
|
||||
"handlerManifest":{
|
||||
"installCommand": "echo 'install'",
|
||||
"uninstallCommand": "echo 'uninstall'",
|
||||
@@ -79,99 +79,99 @@ manJson = {
|
||||
"disableCommand": "echo 'disable'",
|
||||
}
|
||||
}
|
||||
man = ext.HandlerManifest(manJson)
|
||||
manifest_sample = ext.HandlerManifest(manifest_sample_str)
|
||||
|
||||
def MockLoadManifest(self):
|
||||
return man
|
||||
def mock_load_manifest(self):
|
||||
return manifest_sample
|
||||
|
||||
MockLaunchCommand = MockFunc()
|
||||
MockSetHandlerStatus = MockFunc()
|
||||
mock_launch_command = MockFunc()
|
||||
mock_set_handler_status = MockFunc()
|
||||
|
||||
def MockDownload(self):
|
||||
fileutil.CreateDir(self.getBaseDir())
|
||||
fileutil.SetFileContents(self.getManifestFile(), json.dumps(manJson))
|
||||
def mock_download(self):
|
||||
fileutil.mkdir(self.get_base_dir())
|
||||
fileutil.write_file(self.get_manifest_file(), json.dumps(manifest_sample_str))
|
||||
|
||||
#logger.LoggerInit("/dev/null", "/dev/stdout")
|
||||
class TestExtensions(unittest.TestCase):
|
||||
|
||||
def test_load_ext(self):
|
||||
libDir = OSUtil.GetLibDir()
|
||||
testExt1 = os.path.join(libDir, 'TestExt-1.0')
|
||||
testExt2 = os.path.join(libDir, 'TestExt-2.0')
|
||||
testExt2 = os.path.join(libDir, 'TestExt-2.1')
|
||||
for path in [testExt1, testExt2]:
|
||||
libDir = OSUTIL.get_lib_dir()
|
||||
test_ext1 = os.path.join(libDir, 'TestExt-1.0')
|
||||
test_ext2 = os.path.join(libDir, 'TestExt-2.0')
|
||||
test_ext2 = os.path.join(libDir, 'TestExt-2.1')
|
||||
for path in [test_ext1, test_ext2]:
|
||||
if not os.path.isdir(path):
|
||||
os.mkdir(path)
|
||||
testExt = ext.GetInstalledExtensionVersion('TestExt')
|
||||
self.assertEqual('2.1', testExt)
|
||||
test_ext = ext.get_installed_version('TestExt')
|
||||
self.assertEqual('2.1', test_ext)
|
||||
|
||||
def test_getters(self):
|
||||
testExt = ext.ExtensionInstance(extension, packageList,
|
||||
extension.properties.version, False)
|
||||
self.assertEqual("/tmp/TestExt-2.0", testExt.getBaseDir())
|
||||
self.assertEqual("/tmp/TestExt-2.0/status", testExt.getStatusDir())
|
||||
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
|
||||
ext_sample.properties.version, False)
|
||||
self.assertEqual("/tmp/TestExt-2.0", test_ext.get_base_dir())
|
||||
self.assertEqual("/tmp/TestExt-2.0/status", test_ext.get_status_dir())
|
||||
self.assertEqual("/tmp/TestExt-2.0/status/0.status",
|
||||
testExt.getStatusFile())
|
||||
test_ext.get_status_file())
|
||||
self.assertEqual("/tmp/TestExt-2.0/config/HandlerState",
|
||||
testExt.getHandlerStateFile())
|
||||
self.assertEqual("/tmp/TestExt-2.0/config", testExt.getConfigDir())
|
||||
test_ext.get_handler_state_file())
|
||||
self.assertEqual("/tmp/TestExt-2.0/config", test_ext.get_conf_dir())
|
||||
self.assertEqual("/tmp/TestExt-2.0/config/0.settings",
|
||||
testExt.getSettingsFile())
|
||||
test_ext.get_settings_file())
|
||||
self.assertEqual("/tmp/TestExt-2.0/heartbeat.log",
|
||||
testExt.getHeartbeatFile())
|
||||
test_ext.get_heartbeat_file())
|
||||
self.assertEqual("/tmp/TestExt-2.0/HandlerManifest.json",
|
||||
testExt.getManifestFile())
|
||||
test_ext.get_manifest_file())
|
||||
self.assertEqual("/tmp/TestExt-2.0/HandlerEnvironment.json",
|
||||
testExt.getEnvironmentFile())
|
||||
self.assertEqual("/tmp/log/TestExt/2.0", testExt.getLogDir())
|
||||
test_ext.get_env_file())
|
||||
self.assertEqual("/tmp/log/TestExt/2.0", test_ext.get_log_dir())
|
||||
|
||||
testExt = ext.ExtensionInstance(extension, packageList, "2.1", False)
|
||||
self.assertEqual("/tmp/TestExt-2.1", testExt.getBaseDir())
|
||||
self.assertEqual("2.1", testExt.getTargetVersion())
|
||||
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample, "2.1", False)
|
||||
self.assertEqual("/tmp/TestExt-2.1", test_ext.get_base_dir())
|
||||
self.assertEqual("2.1", test_ext.get_target_version())
|
||||
|
||||
@Mockup(ext.ExtensionInstance, 'loadManifest', MockLoadManifest)
|
||||
@Mockup(ext.ExtensionInstance, 'launchCommand', MockLaunchCommand)
|
||||
@Mockup(ext.ExtensionInstance, 'setHandlerStatus', MockSetHandlerStatus)
|
||||
@mock(ext.ExtensionInstance, 'load_manifest', mock_load_manifest)
|
||||
@mock(ext.ExtensionInstance, 'launch_command', mock_launch_command)
|
||||
@mock(ext.ExtensionInstance, 'set_handler_status', mock_set_handler_status)
|
||||
def test_handle_uninstall(self):
|
||||
MockLaunchCommand.args = None
|
||||
MockSetHandlerStatus.args = None
|
||||
testExt = ext.ExtensionInstance(extension, packageList,
|
||||
extension.properties.version, False)
|
||||
testExt.handleUninstall()
|
||||
self.assertEqual(None, MockLaunchCommand.args)
|
||||
self.assertEqual(None, MockSetHandlerStatus.args)
|
||||
self.assertEqual(None, testExt.getCurrOperation())
|
||||
mock_launch_command.args = None
|
||||
mock_set_handler_status.args = None
|
||||
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
|
||||
ext_sample.properties.version, False)
|
||||
test_ext.handle_uninstall()
|
||||
self.assertEqual(None, mock_launch_command.args)
|
||||
self.assertEqual(None, mock_set_handler_status.args)
|
||||
self.assertEqual(None, test_ext.get_curr_op())
|
||||
|
||||
testExt = ext.ExtensionInstance(extension, packageList,
|
||||
extension.properties.version, True)
|
||||
testExt.handleUninstall()
|
||||
self.assertEqual(man.getUninstallCommand(), MockLaunchCommand.args[0])
|
||||
self.assertEqual("UnInstall", testExt.getCurrOperation())
|
||||
self.assertEqual("NotReady", MockSetHandlerStatus.args[0])
|
||||
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
|
||||
ext_sample.properties.version, True)
|
||||
test_ext.handle_uninstall()
|
||||
self.assertEqual(manifest_sample.get_uninstall_command(), mock_launch_command.args[0])
|
||||
self.assertEqual("UnInstall", test_ext.get_curr_op())
|
||||
self.assertEqual("NotReady", mock_set_handler_status.args[0])
|
||||
|
||||
@Mockup(ext.ExtensionInstance, 'loadManifest', MockLoadManifest)
|
||||
@Mockup(ext.ExtensionInstance, 'launchCommand', MockLaunchCommand)
|
||||
@Mockup(ext.ExtensionInstance, 'download', MockDownload)
|
||||
@Mockup(ext.ExtensionInstance, 'getHandlerStatus', MockFunc(retval="enabled"))
|
||||
@Mockup(ext.ExtensionInstance, 'setHandlerStatus', MockSetHandlerStatus)
|
||||
@mock(ext.ExtensionInstance, 'load_manifest', mock_load_manifest)
|
||||
@mock(ext.ExtensionInstance, 'launch_command', mock_launch_command)
|
||||
@mock(ext.ExtensionInstance, 'download', mock_download)
|
||||
@mock(ext.ExtensionInstance, 'get_handler_status', MockFunc(retval="enabled"))
|
||||
@mock(ext.ExtensionInstance, 'set_handler_status', mock_set_handler_status)
|
||||
def test_handle(self):
|
||||
#Test enable
|
||||
testExt = ext.ExtensionInstance(extension, packageList,
|
||||
extension.properties.version, False)
|
||||
testExt.initLog()
|
||||
self.assertEqual(1, len(testExt.logger.appenders) - len(logger.DefaultLogger.appenders))
|
||||
testExt.handle()
|
||||
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
|
||||
ext_sample.properties.version, False)
|
||||
test_ext.init_logger()
|
||||
self.assertEqual(1, len(test_ext.logger.appenders) - len(logger.default_logger.appenders))
|
||||
test_ext.handle()
|
||||
|
||||
#Test upgrade
|
||||
testExt = ext.ExtensionInstance(extension, packageList,
|
||||
extension.properties.version, False)
|
||||
testExt.initLog()
|
||||
self.assertEqual(1, len(testExt.logger.appenders) - len(logger.DefaultLogger.appenders))
|
||||
testExt.handle()
|
||||
test_ext = ext.ExtensionInstance(ext_sample, pkg_list_sample,
|
||||
ext_sample.properties.version, False)
|
||||
test_ext.init_logger()
|
||||
self.assertEqual(1, len(test_ext.logger.appenders) - len(logger.default_logger.appenders))
|
||||
test_ext.handle()
|
||||
|
||||
def test_status_convert(self):
|
||||
extStatus = json.loads('[{"status": {"status": "success", "formattedMessage": {"lang": "en-US", "message": "Script is finished"}, "operation": "Enable", "code": "0", "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"}, "version": "1.0", "timestampUTC": "2015-06-27T08:34:50Z"}]')
|
||||
ext.extension_status_to_v2(extStatus[0], 0)
|
||||
ext_status = json.loads('[{"status": {"status": "success", "formattedMessage": {"lang": "en-US", "message": "Script is finished"}, "operation": "Enable", "code": "0", "name": "Microsoft.OSTCExtensions.CustomScriptForLinux"}, "version": "1.0", "timestampUTC": "2015-06-27T08:34:50Z"}]')
|
||||
ext.ext_status_to_v2(ext_status[0], 0)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -26,7 +26,7 @@ import os
|
||||
import json
|
||||
import azurelinuxagent.protocol.v1 as v1
|
||||
|
||||
ExtensionsConfigSample="""\
|
||||
ext_conf_sample="""\
|
||||
<Extensions version="1.0.0.0" goalStateIncarnation="9"><GuestAgentExtension xmlns:i="http://www.w3.org/2001/XMLSchema-instance">
|
||||
<GAFamilies>
|
||||
<GAFamily>
|
||||
@@ -68,13 +68,13 @@ ExtensionsConfigSample="""\
|
||||
</Plugins>
|
||||
<PluginSettings>
|
||||
<Plugin name="OSTCExtensions.ExampleHandlerLinux" version="1.4">
|
||||
<RuntimeSettings seqNo="6">{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3","protectedSettings":"MIICWgYJK","publicSettings":{"foo":"bar"}}}]}</RuntimeSettings>
|
||||
<runtimeSettings seqNo="6">{"runtimeSettings":[{"handlerSettings":{"protectedSettingsCertThumbprint":"4037FBF5F1F3014F99B5D6C7799E9B20E6871CB3","protectedSettings":"MIICWgYJK","publicSettings":{"foo":"bar"}}}]}</runtimeSettings>
|
||||
</Plugin>
|
||||
</PluginSettings>
|
||||
<StatusUploadBlob>https://yuezhatest.blob.core.windows.net/vhds/test-cs12.test-cs12.test-cs12.status?sr=b&sp=rw&se=9999-01-01&sk=key1&sv=2014-02-14&sig=hfRh7gzUE7sUtYwke78IOlZOrTRCYvkec4hGZ9zZzXo%3D</StatusUploadBlob></Extensions>
|
||||
"""
|
||||
|
||||
ManifestSample="""\
|
||||
manifest_sample="""\
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<PluginVersionManifest xmlns:i="http://www.w3.org/2001/XMLSchema-instance">
|
||||
<Plugins>
|
||||
@@ -113,8 +113,8 @@ EmptySettings="""\
|
||||
|
||||
class TestExtensionsConfig(unittest.TestCase):
|
||||
def test_extensions_config(self):
|
||||
config = v1.ExtensionsConfig(ExtensionsConfigSample)
|
||||
extensions = config.extList.extensions
|
||||
config = v1.ExtensionsConfig(ext_conf_sample)
|
||||
extensions = config.ext_list.extensions
|
||||
self.assertNotEquals(None, extensions)
|
||||
self.assertEquals(1, len(extensions))
|
||||
self.assertNotEquals(None, extensions[0])
|
||||
@@ -131,9 +131,9 @@ class TestExtensionsConfig(unittest.TestCase):
|
||||
self.assertEquals(json.loads('{"foo":"bar"}'),
|
||||
settings.publicSettings)
|
||||
|
||||
man = v1.ExtensionManifest(ManifestSample)
|
||||
self.assertNotEquals(None, man.packageList)
|
||||
self.assertEquals(3, len(man.packageList.versions))
|
||||
man = v1.ExtensionManifest(manifest_sample)
|
||||
self.assertNotEquals(None, man.pkg_list)
|
||||
self.assertEquals(3, len(man.pkg_list.versions))
|
||||
|
||||
def test_empty_settings(self):
|
||||
config = v1.ExtensionsConfig(EmptySettings)
|
||||
|
||||
@@ -30,36 +30,36 @@ class TestFileOperations(unittest.TestCase):
|
||||
def test_get_set_file_contents(self):
|
||||
test_file='/tmp/test_file'
|
||||
content = str(uuid.uuid4())
|
||||
fileutil.SetFileContents(test_file, content)
|
||||
fileutil.write_file(test_file, content)
|
||||
self.assertTrue(tools.simple_file_grep(test_file, content))
|
||||
self.assertEquals(content, fileutil.GetFileContents('/tmp/test_file'))
|
||||
self.assertEquals(content, fileutil.read_file('/tmp/test_file'))
|
||||
os.remove(test_file)
|
||||
|
||||
def test_append_file(self):
|
||||
test_file='/tmp/test_file2'
|
||||
content = str(uuid.uuid4())
|
||||
fileutil.AppendFileContents(test_file, content)
|
||||
fileutil.append_file(test_file, content)
|
||||
self.assertTrue(tools.simple_file_grep(test_file, content))
|
||||
os.remove(test_file)
|
||||
|
||||
def test_replace_file(self):
|
||||
test_file='/tmp/test_file3'
|
||||
contentOld = str(uuid.uuid4())
|
||||
old_content = str(uuid.uuid4())
|
||||
content = str(uuid.uuid4())
|
||||
with open(test_file, "a+") as F:
|
||||
F.write(contentOld)
|
||||
fileutil.ReplaceFileContentsAtomic(test_file, content)
|
||||
self.assertFalse(tools.simple_file_grep(test_file, contentOld))
|
||||
F.write(old_content)
|
||||
fileutil.replace_file(test_file, content)
|
||||
self.assertFalse(tools.simple_file_grep(test_file, old_content))
|
||||
self.assertTrue(tools.simple_file_grep(test_file, content))
|
||||
os.remove(test_file)
|
||||
|
||||
def test_get_last_path_element(self):
|
||||
filepath = '/tmp/abc.def'
|
||||
filename = fileutil.GetLastPathElement(filepath)
|
||||
filename = fileutil.base_name(filepath)
|
||||
self.assertEquals('abc.def', filename)
|
||||
|
||||
filepath = '/tmp/abc'
|
||||
filename = fileutil.GetLastPathElement(filepath)
|
||||
filename = fileutil.base_name(filepath)
|
||||
self.assertEquals('abc', filename)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
+10
-10
@@ -26,7 +26,7 @@ import os
|
||||
import test
|
||||
import azurelinuxagent.protocol.v1 as v1
|
||||
|
||||
GoalStateSample="""
|
||||
goal_state_sample="""
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<GoalState xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:noNamespaceSchemaLocation="goalstate10.xsd">
|
||||
<Version>2010-12-15</Version>
|
||||
@@ -58,15 +58,15 @@ GoalStateSample="""
|
||||
|
||||
class TestGoalState(unittest.TestCase):
|
||||
def test_goal_state(self):
|
||||
goalState = v1.GoalState(GoalStateSample)
|
||||
self.assertEquals('1', goalState.getIncarnation())
|
||||
self.assertNotEquals(None, goalState.getExpectedState())
|
||||
self.assertNotEquals(None, goalState.getHostingEnvUri())
|
||||
self.assertNotEquals(None, goalState.getSharedConfigUri())
|
||||
self.assertNotEquals(None, goalState.getCertificatesUri())
|
||||
self.assertNotEquals(None, goalState.getExtensionsUri())
|
||||
self.assertNotEquals(None, goalState.getRoleInstanceId())
|
||||
self.assertNotEquals(None, goalState.getContainerId())
|
||||
goal_state = v1.GoalState(goal_state_sample)
|
||||
self.assertEquals('1', goal_state.incarnation)
|
||||
self.assertNotEquals(None, goal_state.expected_state)
|
||||
self.assertNotEquals(None, goal_state.hosting_env_uri)
|
||||
self.assertNotEquals(None, goal_state.shared_conf_uri)
|
||||
self.assertNotEquals(None, goal_state.certs_uri)
|
||||
self.assertNotEquals(None, goal_state.ext_uri)
|
||||
self.assertNotEquals(None, goal_state.role_instance_id)
|
||||
self.assertNotEquals(None, goal_state.container_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -25,7 +25,7 @@ import unittest
|
||||
import os
|
||||
import azurelinuxagent.protocol.v1 as v1
|
||||
|
||||
HostingEnvSample="""
|
||||
hosting_env_sample="""
|
||||
<HostingEnvironmentConfig version="1.0.0.0" goalStateIncarnation="1">
|
||||
<StoredCertificates>
|
||||
<StoredCertificate name="Stored0Microsoft.WindowsAzure.Plugins.RemoteAccess.PasswordEncryption" certificateId="sha1:C093FA5CD3AAE057CB7C4E04532B2E16E07C26CA" storeName="My" configurationLevel="System" />
|
||||
@@ -36,7 +36,7 @@ HostingEnvSample="""
|
||||
</Deployment>
|
||||
<Incarnation number="1" instance="MachineRole_IN_0" guid="{a0faca35-52e5-4ec7-8fd1-63d2bc107d9b}" />
|
||||
<Role guid="{73d95f1c-6472-e58e-7a1a-523554e11d46}" name="MachineRole" hostingEnvironmentVersion="1" software="" softwareType="ApplicationPackage" entryPoint="" parameters="" settleTimeSeconds="10" />
|
||||
<HostingEnvironmentSettings name="full" Runtime="rd_fabric_stable.110217-1402.RuntimePackage_1.0.0.8.zip">
|
||||
<HostingEnvironmentSettings name="full" runtime="rd_fabric_stable.110217-1402.runtimePackage_1.0.0.8.zip">
|
||||
<CAS mode="full" />
|
||||
<PrivilegeLevel mode="max" />
|
||||
<AdditionalProperties><CgiHandlers></CgiHandlers></AdditionalProperties>
|
||||
@@ -53,13 +53,13 @@ HostingEnvSample="""
|
||||
"""
|
||||
|
||||
class TestHostingEvn(unittest.TestCase):
|
||||
def test_hostingenv(self):
|
||||
hostingenv = v1.HostingEnv(HostingEnvSample)
|
||||
self.assertNotEquals(None, hostingenv)
|
||||
self.assertEquals("MachineRole_IN_0", hostingenv.getVmName())
|
||||
self.assertEquals("MachineRole", hostingenv.getRoleName())
|
||||
def test_hosting_env(self):
|
||||
hosting_env = v1.HostingEnv(hosting_env_sample)
|
||||
self.assertNotEquals(None, hosting_env)
|
||||
self.assertEquals("MachineRole_IN_0", hosting_env.vm_name)
|
||||
self.assertEquals("MachineRole", hosting_env.role_name)
|
||||
self.assertEquals("db00a7755a5e4e8a8fe4b19bc3b330c3",
|
||||
hostingenv.getDeploymentName())
|
||||
hosting_env.deployment_name)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -51,7 +51,7 @@ class TestLogger(unittest.TestCase):
|
||||
|
||||
def test_file_appender(self):
|
||||
_logger = logger.Logger()
|
||||
_logger.addLoggerAppender(logger.AppenderType.FILE,
|
||||
_logger.add_appender(logger.AppenderType.FILE,
|
||||
logger.LogLevel.INFO,
|
||||
'/tmp/testlog')
|
||||
|
||||
@@ -69,14 +69,14 @@ class TestLogger(unittest.TestCase):
|
||||
|
||||
def test_log_to_non_exists_dev(self):
|
||||
_logger = logger.Logger()
|
||||
_logger.addLoggerAppender(logger.AppenderType.CONSOLE,
|
||||
_logger.add_appender(logger.AppenderType.CONSOLE,
|
||||
logger.LogLevel.INFO,
|
||||
'/dev/nonexists')
|
||||
_logger.info("something")
|
||||
|
||||
def test_log_to_non_exists_file(self):
|
||||
_logger = logger.Logger()
|
||||
_logger.addLoggerAppender(logger.AppenderType.FILE,
|
||||
_logger.add_appender(logger.AppenderType.FILE,
|
||||
logger.LogLevel.INFO,
|
||||
'/tmp/nonexists')
|
||||
_logger.info("something")
|
||||
|
||||
@@ -21,16 +21,16 @@
|
||||
import env
|
||||
from tests.tools import *
|
||||
import unittest
|
||||
from azurelinuxagent.metadata import GuestAgentName, GuestAgentVersion, \
|
||||
DistroName, DistroVersion, DistroCodeName, \
|
||||
DistroFullName
|
||||
from azurelinuxagent.metadata import AGENT_NAME, AGENT_VERSION, \
|
||||
DISTRO_NAME, DISTRO_VERSION, DISTRO_CODE_NAME, \
|
||||
DISTRO_FULL_NAME
|
||||
|
||||
class TestOSInfo(unittest.TestCase):
|
||||
def test_curr_os_info(self):
|
||||
self.assertNotEquals(None, DistroName)
|
||||
self.assertNotEquals(None, DistroVersion)
|
||||
self.assertNotEquals(None, DistroCodeName)
|
||||
self.assertNotEquals(None, DistroFullName)
|
||||
self.assertNotEquals(None, DISTRO_NAME)
|
||||
self.assertNotEquals(None, DISTRO_VERSION)
|
||||
self.assertNotEquals(None, DISTRO_CODE_NAME)
|
||||
self.assertNotEquals(None, DISTRO_FULL_NAME)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
+71
-71
@@ -28,14 +28,14 @@ import time
|
||||
import azurelinuxagent.utils.fileutil as fileutil
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import azurelinuxagent.conf as conf
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
import test
|
||||
|
||||
class TestOSUtil(unittest.TestCase):
|
||||
def test_current_distro(self):
|
||||
self.assertNotEquals(None, OSUtil)
|
||||
self.assertNotEquals(None, OSUTIL)
|
||||
|
||||
MountlistSample="""\
|
||||
mount_list_sample="""\
|
||||
/dev/sda1 on / type ext4 (rw)
|
||||
proc on /proc type proc (rw)
|
||||
sysfs on /sys type sysfs (rw)
|
||||
@@ -48,29 +48,29 @@ none on /proc/sys/fs/binfmt_misc type binfmt_misc (rw)
|
||||
class TestCurrOS(unittest.TestCase):
|
||||
#class TestCurrOS(object):
|
||||
def test_get_paths(self):
|
||||
self.assertNotEquals(None, OSUtil.GetHome())
|
||||
self.assertNotEquals(None, OSUtil.GetLibDir())
|
||||
self.assertNotEquals(None, OSUtil.GetAgentPidPath())
|
||||
self.assertNotEquals(None, OSUtil.GetConfigurationPath())
|
||||
self.assertNotEquals(None, OSUtil.GetDvdMountPoint())
|
||||
self.assertNotEquals(None, OSUtil.GetOvfEnvPathOnDvd())
|
||||
self.assertNotEquals(None, OSUTIL.get_home())
|
||||
self.assertNotEquals(None, OSUTIL.get_lib_dir())
|
||||
self.assertNotEquals(None, OSUTIL.get_agent_pid_file_path())
|
||||
self.assertNotEquals(None, OSUTIL.get_conf_file_path())
|
||||
self.assertNotEquals(None, OSUTIL.get_dvd_mount_point())
|
||||
self.assertNotEquals(None, OSUTIL.get_ovf_env_file_path_on_dvd())
|
||||
|
||||
@Mockup(shellutil, 'RunGetOutput', MockFunc(retval=[0, '']))
|
||||
@Mockup(shellutil, 'RunSendStdin', MockFunc(retval=[0, '']))
|
||||
@Mockup(fileutil, 'SetFileContents', MockFunc())
|
||||
@Mockup(fileutil, 'AppendFileContents', MockFunc())
|
||||
@Mockup(fileutil, 'GetFileContents', MockFunc(retval=''))
|
||||
@Mockup(fileutil, 'ChangeMod', MockFunc())
|
||||
@mock(fileutil, 'write_file', MockFunc())
|
||||
@mock(fileutil, 'append_file', MockFunc())
|
||||
@mock(fileutil, 'chmod', MockFunc())
|
||||
@mock(fileutil, 'read_file', MockFunc(retval=''))
|
||||
@mock(shellutil, 'run', MockFunc())
|
||||
@mock(shellutil, 'run_get_output', MockFunc(retval=[0, '']))
|
||||
def test_update_user_account(self):
|
||||
OSUtil.UpdateUserAccount('api', 'api')
|
||||
OSUtil.DeleteAccount('api')
|
||||
OSUTIL.set_user_account('api', 'api')
|
||||
OSUTIL.del_account('api')
|
||||
|
||||
@Mockup(fileutil, 'GetFileContents', MockFunc(retval='root::::'))
|
||||
@Mockup(fileutil, 'SetFileContents', MockFunc())
|
||||
@mock(fileutil, 'read_file', MockFunc(retval='root::::'))
|
||||
@mock(fileutil, 'write_file', MockFunc())
|
||||
def test_delete_root_password(self):
|
||||
OSUtil.DeleteRootPassword()
|
||||
OSUTIL.del_root_password()
|
||||
self.assertEquals('root:*LOCK*:14600::::::',
|
||||
fileutil.SetFileContents.args[1])
|
||||
fileutil.write_file.args[1])
|
||||
|
||||
def test_cert_operation(self):
|
||||
if os.path.isfile('/tmp/test.prv'):
|
||||
@@ -81,84 +81,84 @@ class TestCurrOS(unittest.TestCase):
|
||||
os.remove('/tmp/test.crt')
|
||||
shutil.copyfile(os.path.join(env.test_root, 'test.crt'),
|
||||
'/tmp/test.crt')
|
||||
pub1 = OSUtil.GetPubKeyFromPrv('/tmp/test.prv')
|
||||
pub2 = OSUtil.GetPubKeyFromCrt('/tmp/test.crt')
|
||||
pub1 = OSUTIL.get_pubkey_from_prv('/tmp/test.prv')
|
||||
pub2 = OSUTIL.get_pubkey_from_crt('/tmp/test.crt')
|
||||
self.assertEquals(pub1, pub2)
|
||||
thumbprint = OSUtil.GetThumbprintFromCrt('/tmp/test.crt')
|
||||
thumbprint = OSUTIL.get_thumbprint_from_crt('/tmp/test.crt')
|
||||
self.assertEquals('33B0ABCE4673538650971C10F7D7397E71561F35', thumbprint)
|
||||
|
||||
def test_selinux(self):
|
||||
if not OSUtil.IsSelinuxSystem():
|
||||
if not OSUTIL.is_selinux_system():
|
||||
return
|
||||
isRunning = OSUtil.IsSelinuxRunning()
|
||||
if not OSUtil.IsSelinuxRunning():
|
||||
OSUtil.SetSelinuxEnforce(0)
|
||||
self.assertEquals(False, OSUtil.IsSelinuxRunning())
|
||||
OSUtil.SetSelinuxEnforce(1)
|
||||
self.assertEquals(True, OSUtil.IsSelinuxRunning())
|
||||
isrunning = OSUTIL.is_selinux_enforcing()
|
||||
if not OSUTIL.is_selinux_enforcing():
|
||||
OSUTIL.set_selinux_enforce(0)
|
||||
self.assertEquals(False, OSUTIL.is_selinux_enforcing())
|
||||
OSUTIL.set_selinux_enforce(1)
|
||||
self.assertEquals(True, OSUTIL.is_selinux_enforcing())
|
||||
if os.path.isfile('/tmp/abc'):
|
||||
os.remove('/tmp/abc')
|
||||
fileutil.SetFileContents('/tmp/abc', '')
|
||||
OSUtil.SetSelinuxContext('/tmp/abc','unconfined_u:object_r:ssh_home_t:s')
|
||||
OSUtil.SetSelinuxEnforce(1 if isRunning else 0)
|
||||
fileutil.write_file('/tmp/abc', '')
|
||||
OSUTIL.set_selinux_context('/tmp/abc','unconfined_u:object_r:ssh_home_t:s')
|
||||
OSUTIL.set_selinux_enforce(1 if isrunning else 0)
|
||||
|
||||
@Mockup(shellutil, 'RunGetOutput', MockFunc(retval=[0, '']))
|
||||
@Mockup(fileutil, 'SetFileContents', MockFunc())
|
||||
@mock(shellutil, 'run_get_output', MockFunc(retval=[0, '']))
|
||||
@mock(fileutil, 'write_file', MockFunc())
|
||||
def test_network_operation(self):
|
||||
OSUtil.StartNetwork()
|
||||
OSUtil.OpenPortForDhcp()
|
||||
OSUtil.GenerateTransportCert()
|
||||
mac = OSUtil.GetMacAddress()
|
||||
OSUTIL.start_network()
|
||||
OSUTIL.allow_dhcp_broadcast()
|
||||
OSUTIL.gen_transport_cert()
|
||||
mac = OSUTIL.get_mac_addr()
|
||||
self.assertNotEquals(None, mac)
|
||||
OSUtil.IsMissingDefaultRoute()
|
||||
OSUtil.SetBroadcastRouteForDhcp('api')
|
||||
OSUtil.RemoveBroadcastRouteForDhcp('api')
|
||||
OSUtil.RouteAdd('', '', '')
|
||||
OSUtil.GetDhcpProcessId()
|
||||
OSUtil.SetHostname('api')
|
||||
OSUtil.PublishHostname('api')
|
||||
OSUTIL.is_missing_default_route()
|
||||
OSUTIL.set_route_for_dhcp_broadcast('api')
|
||||
OSUTIL.remove_route_for_dhcp_broadcast('api')
|
||||
OSUTIL.route_add('', '', '')
|
||||
OSUTIL.get_dhcp_pid()
|
||||
OSUTIL.set_hostname('api')
|
||||
OSUTIL.publish_hostname('api')
|
||||
|
||||
@Mockup(OSUtil, 'GetHome', MockFunc(retval='/tmp/home'))
|
||||
@mock(OSUTIL, 'get_home', MockFunc(retval='/tmp/home'))
|
||||
def test_deploy_key(self):
|
||||
if os.path.isdir('/tmp/home'):
|
||||
shutil.rmtree('/tmp/home')
|
||||
user = shellutil.RunGetOutput('whoami')[1].strip()
|
||||
OSUtil.DeploySshKeyPair(user, 'test', '$HOME/.ssh/id_rsa')
|
||||
OSUtil.DeploySshPublicKey(user, 'test', '$HOME/.ssh/authorized_keys')
|
||||
user = shellutil.run_get_output('whoami')[1].strip()
|
||||
OSUTIL.deploy_ssh_keypair(user, 'test', '$HOME/.ssh/id_rsa')
|
||||
OSUTIL.deploy_ssh_pubkey(user, 'test', '$HOME/.ssh/authorized_keys')
|
||||
self.assertTrue(os.path.isfile('/tmp/home/.ssh/id_rsa'))
|
||||
self.assertTrue(os.path.isfile('/tmp/home/.ssh/id_rsa.pub'))
|
||||
self.assertTrue(os.path.isfile('/tmp/home/.ssh/authorized_keys'))
|
||||
|
||||
@Mockup(shellutil, 'RunGetOutput', MockFunc(retval=[0, '']))
|
||||
@Mockup(OSUtil, 'GetSshdConfigPath', MockFunc(retval='/tmp/sshd_config'))
|
||||
@mock(shellutil, 'run_get_output', MockFunc(retval=[0, '']))
|
||||
@mock(OSUTIL, 'get_sshd_conf_file_path', MockFunc(retval='/tmp/sshd_config'))
|
||||
def test_ssh_operation(self):
|
||||
shellutil.RunGetOutput.retval=[0,
|
||||
shellutil.run_get_output.retval=[0,
|
||||
'2048 f1:fe:14:66:9d:46:9a:60:8b:8c:'
|
||||
'80:43:39:1c:20:9e root@api (RSA)']
|
||||
sshdConfig = OSUtil.GetSshdConfigPath()
|
||||
self.assertEquals('/tmp/sshd_config', sshdConfig)
|
||||
if os.path.isfile(sshdConfig):
|
||||
os.remove(sshdConfig)
|
||||
shutil.copyfile(os.path.join(env.test_root, 'sshd_config'), sshdConfig)
|
||||
OSUtil.SetSshClientAliveInterval()
|
||||
OSUtil.ConfigSshd(True)
|
||||
self.assertTrue(simple_file_grep(sshdConfig,
|
||||
sshd_conf = OSUTIL.get_sshd_conf_file_path()
|
||||
self.assertEquals('/tmp/sshd_config', sshd_conf)
|
||||
if os.path.isfile(sshd_conf):
|
||||
os.remove(sshd_conf)
|
||||
shutil.copyfile(os.path.join(env.test_root, 'sshd_config'), sshd_conf)
|
||||
OSUTIL.set_ssh_client_alive_interval()
|
||||
OSUTIL.conf_sshd(True)
|
||||
self.assertTrue(simple_file_grep(sshd_conf,
|
||||
'PasswordAuthentication no'))
|
||||
self.assertTrue(simple_file_grep(sshdConfig,
|
||||
self.assertTrue(simple_file_grep(sshd_conf,
|
||||
'ChallengeResponseAuthentication no'))
|
||||
self.assertTrue(simple_file_grep(sshdConfig,
|
||||
self.assertTrue(simple_file_grep(sshd_conf,
|
||||
'ClientAliveInterval 180'))
|
||||
|
||||
@Mockup(shellutil, 'RunGetOutput', MockFunc(retval=[0, '']))
|
||||
@Mockup(OSUtil, 'GetDvdMountPoint', MockFunc(retval='/tmp/cdrom'))
|
||||
@mock(shellutil, 'run_get_output', MockFunc(retval=[0, '']))
|
||||
@mock(OSUTIL, 'get_mount_point', MockFunc(retval='/tmp/cdrom'))
|
||||
def test_mount(self):
|
||||
OSUtil.MountDvd()
|
||||
OSUtil.UmountDvd()
|
||||
mountPoint = OSUtil.GetMountPoint(MountlistSample, '/dev/sda')
|
||||
self.assertNotEquals(None, mountPoint)
|
||||
OSUTIL.mount_dvd()
|
||||
OSUTIL.umount_dvd()
|
||||
mount_point = OSUTIL.get_mount_point(mount_list_sample, '/dev/sda')
|
||||
self.assertNotEquals(None, mount_point)
|
||||
|
||||
def test_getdvd(self):
|
||||
OSUtil.GetDvdDevice()
|
||||
OSUTIL.get_dvd_device()
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
+11
-11
@@ -59,17 +59,17 @@ ExtensionsConfigSample="""
|
||||
class TestOvf(unittest.TestCase):
|
||||
def test_ovf(self):
|
||||
config = ovfenv.OvfEnv(ExtensionsConfigSample)
|
||||
self.assertEquals(1, config.getMajorVersion())
|
||||
self.assertEquals(0, config.getMinorVersion())
|
||||
self.assertEquals("HostName", config.getComputerName())
|
||||
self.assertEquals("UserName", config.getUserName())
|
||||
self.assertEquals("UserPassword", config.getUserPassword())
|
||||
self.assertEquals(False, config.getDisableSshPasswordAuthentication())
|
||||
self.assertEquals("CustomData", config.getCustomData())
|
||||
self.assertNotEquals(None, config.getSshPublicKeys())
|
||||
self.assertEquals(1, len(config.getSshPublicKeys()))
|
||||
self.assertNotEquals(None, config.getSshKeyPairs())
|
||||
self.assertEquals(1, len(config.getSshKeyPairs()))
|
||||
self.assertEquals(1, config.get_major_version())
|
||||
self.assertEquals(0, config.get_minor_version())
|
||||
self.assertEquals("HostName", config.get_computer_name())
|
||||
self.assertEquals("UserName", config.get_username())
|
||||
self.assertEquals("UserPassword", config.get_user_password())
|
||||
self.assertEquals(False, config.get_disable_ssh_password_auth())
|
||||
self.assertEquals("CustomData", config.get_customdata())
|
||||
self.assertNotEquals(None, config.get_ssh_pubkeys())
|
||||
self.assertEquals(1, len(config.get_ssh_pubkeys()))
|
||||
self.assertNotEquals(None, config.get_ssh_keypairs())
|
||||
self.assertEquals(1, len(config.get_ssh_keypairs()))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -62,7 +62,7 @@ extensionDataStr = """
|
||||
|
||||
class TestProtocolContract(unittest.TestCase):
|
||||
def test_get_properties(self):
|
||||
data = get_properties(VmInfo())
|
||||
data = get_properties(VMInfo())
|
||||
data = get_properties(Cert())
|
||||
data = get_properties(ExtensionPackageList())
|
||||
data = get_properties(InstanceMetadata())
|
||||
|
||||
@@ -29,11 +29,11 @@ import azurelinuxagent.protocol.protocolFactory as protocolFactory
|
||||
class TestWireProtocolEndpoint(unittest.TestCase):
|
||||
def test_get_available_protocol(self):
|
||||
with self.assertRaises(protocol.ProtocolNotFound):
|
||||
protocol.Factory.getDefaultProtocol()
|
||||
protocol.Factory.get_default_protocol()
|
||||
|
||||
def test_get_available_protocols(self):
|
||||
mockGetV1 = MockFunc(retval="Mock protocol")
|
||||
protocols = protocolFactory.GetAvailableProtocols([mockGetV1])
|
||||
protocols = protocolFactory.get_available_protocols([mockGetV1])
|
||||
self.assertNotEquals(None, protocols)
|
||||
self.assertNotEquals(0, len(protocols))
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from tests.tools import *
|
||||
import unittest
|
||||
from azurelinuxagent.distro.redhat.osutil import RedhatOSUtil
|
||||
|
||||
TestPublicKey="""\
|
||||
test_pubkey="""\
|
||||
-----BEGIN PUBLIC KEY-----
|
||||
MIIBIDANBgkqhkiG9w0BAQEFAAOCAQ0AMIIBCAKCAQEA2wo22vf1N8NWE+5lLfit
|
||||
T7uzkfwqdw0IAoHZ0l2BtP0ajy6f835HCR3w3zLWw5ut7Xvyo26x1OMOzjo5lqtM
|
||||
@@ -35,15 +35,15 @@ fQIBIw==
|
||||
-----END PUBLIC KEY-----
|
||||
"""
|
||||
|
||||
Expected="""\
|
||||
expected_ssh_rsa_pubkey="""\
|
||||
ssh-rsa AAAAB3NzaC1yc2EAAAABIwAAAQEA2wo22vf1N8NWE+5lLfitT7uzkfwqdw0IAoHZ0l2BtP0ajy6f835HCR3w3zLWw5ut7Xvyo26x1OMOzjo5lqtMh8iyQwfHtWf6Cekxfkf+6Pca99bNuDgwRopOTOyoVgwDzJB0+slpn/sJjeGbhxJlToT8tNPLrBmnnpaMZLMIANcPQtTRCQcV/ycv+/omKXFB+zULYkN8v22o5mysoCuQfzXiJP3Mlnf+V2XMl1WAJylhOJif04K8j+G8oF5ECBIQiph4ZLQS1yTYlozPXU8k8vB6A5+UiOGxBnOQYnp42cS5d4qSQ8LORCRGXrCj4DCP+lvkUDLUHx2WN+1ivZkOfQ==
|
||||
"""
|
||||
|
||||
class TestRedhat(unittest.TestCase):
|
||||
def test_RsaPublicKeyToSshRsa(self):
|
||||
OSUtil = RedhatOSUtil()
|
||||
sshRsaPublicKey = OSUtil.RsaPublicKeyToSshRsa(TestPublicKey)
|
||||
self.assertEquals(Expected, sshRsaPublicKey)
|
||||
ssh_rsa_pubkey = OSUtil.asn1_to_ssh_rsa(test_pubkey)
|
||||
self.assertEquals(expected_ssh_rsa_pubkey, ssh_rsa_pubkey)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
+12
-12
@@ -23,11 +23,11 @@ from tests.tools import *
|
||||
import unittest
|
||||
import azurelinuxagent.distro.default.resourceDisk as rdh
|
||||
import azurelinuxagent.logger as logger
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
|
||||
#logger.LoggerInit("/dev/null", "/dev/stdout")
|
||||
|
||||
MockGPTOutput="""
|
||||
gpt_output_sample="""
|
||||
Model: Msft Virtual Disk (scsi)
|
||||
Disk /dev/sda: 32.2GB
|
||||
Sector size (logical/physical): 512B/4096B
|
||||
@@ -40,24 +40,24 @@ Number Start End Size Type File system Flags
|
||||
|
||||
class TestResourceDisk(unittest.TestCase):
|
||||
|
||||
@Mockup(rdh.OSUtil, 'DeviceForIdePort', MockFunc(retval='foo'))
|
||||
@Mockup(rdh.shellutil, 'RunGetOutput', MockFunc(retval=(0, MockGPTOutput)))
|
||||
@Mockup(rdh.shellutil, 'Run', MockFunc(retval=0))
|
||||
@mock(rdh.OSUtil, 'device_for_ide_port', MockFunc(retval='foo'))
|
||||
@mock(rdh.shellutil, 'run_get_output', MockFunc(retval=(0, gpt_output_sample)))
|
||||
@mock(rdh.shellutil, 'run', MockFunc(retval=0))
|
||||
def test_mountGPT(self):
|
||||
handler = rdh.ResourceDiskHandler()
|
||||
handler.mountResourceDisk('/tmp/foo', 'ext4')
|
||||
handler.mount_resource_disk('/tmp/foo', 'ext4')
|
||||
|
||||
@Mockup(rdh.OSUtil, 'DeviceForIdePort', MockFunc(retval='foo'))
|
||||
@Mockup(rdh.shellutil, 'RunGetOutput', MockFunc(retval=(0, "")))
|
||||
@Mockup(rdh.shellutil, 'Run', MockFunc(retval=0))
|
||||
@mock(rdh.OSUtil, 'device_for_ide_port', MockFunc(retval='foo'))
|
||||
@mock(rdh.shellutil, 'run_get_output', MockFunc(retval=(0, "")))
|
||||
@mock(rdh.shellutil, 'run', MockFunc(retval=0))
|
||||
def test_mountMBR(self):
|
||||
handler = rdh.ResourceDiskHandler()
|
||||
handler.mountResourceDisk('/tmp/foo', 'ext4')
|
||||
handler.mount_resource_disk('/tmp/foo', 'ext4')
|
||||
|
||||
@Mockup(rdh.shellutil, 'Run', MockFunc(retval=0))
|
||||
@mock(rdh.shellutil, 'run', MockFunc(retval=0))
|
||||
def test_createSwapSpace(self):
|
||||
handler = rdh.ResourceDiskHandler()
|
||||
handler.createSwapSpace('/tmp/foo', 512)
|
||||
handler.create_swap_space('/tmp/foo', 512)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -33,32 +33,32 @@ import azurelinuxagent.logger as logger
|
||||
class TestHttpOperations(unittest.TestCase):
|
||||
|
||||
def test_parse_url(self):
|
||||
host, port, secure, relativeUrl = restutil._ParseUrl("http://abc.def/ghi#hash?jkl=mn")
|
||||
host, port, secure, rel_uri = restutil._parse_url("http://abc.def/ghi#hash?jkl=mn")
|
||||
self.assertEquals("abc.def", host)
|
||||
self.assertEquals("/ghi#hash?jkl=mn", relativeUrl)
|
||||
self.assertEquals("/ghi#hash?jkl=mn", rel_uri)
|
||||
|
||||
host, port, secure, relativeUrl = restutil._ParseUrl("http://abc.def/")
|
||||
host, port, secure, rel_uri = restutil._parse_url("http://abc.def/")
|
||||
self.assertEquals("abc.def", host)
|
||||
self.assertEquals("/", relativeUrl)
|
||||
self.assertEquals("/", rel_uri)
|
||||
self.assertEquals(False, secure)
|
||||
|
||||
host, port, secure, relativeUrl = restutil._ParseUrl("https://abc.def/ghi?jkl=mn")
|
||||
host, port, secure, rel_uri = restutil._parse_url("https://abc.def/ghi?jkl=mn")
|
||||
self.assertEquals(True, secure)
|
||||
|
||||
host, port, secure, relativeUrl = restutil._ParseUrl("http://abc.def:80/")
|
||||
host, port, secure, rel_uri = restutil._parse_url("http://abc.def:80/")
|
||||
self.assertEquals("abc.def", host)
|
||||
|
||||
def _test_http_get(self):
|
||||
resp = restutil.HttpGet("http://httpbin.org/get").read()
|
||||
resp = restutil.http_get("http://httpbin.org/get").read()
|
||||
self.assertNotEquals(None, resp)
|
||||
|
||||
msg = str(uuid.uuid4())
|
||||
resp = restutil.HttpGet("http://httpbin.org/get", {"x-abc":msg}).read()
|
||||
resp = restutil.http_get("http://httpbin.org/get", {"x-abc":msg}).read()
|
||||
self.assertNotEquals(None, resp)
|
||||
self.assertTrue(msg in resp)
|
||||
|
||||
def _test_https_get(self):
|
||||
resp = restutil.HttpGet("https://httpbin.org/get").read()
|
||||
resp = restutil.http_get("https://httpbin.org/get").read()
|
||||
self.assertNotEquals(None, resp)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
@@ -25,7 +25,7 @@ import unittest
|
||||
import os
|
||||
import azurelinuxagent.protocol.v1 as v1
|
||||
|
||||
SharedConfigSample="""
|
||||
shared_config_sample="""
|
||||
|
||||
<SharedConfig version="1.0.0.0" goalStateIncarnation="1">
|
||||
<Deployment name="db00a7755a5e4e8a8fe4b19bc3b330c3" guid="{ce5a036f-5c93-40e7-8adf-2613631008ab}" incarnation="2">
|
||||
@@ -73,7 +73,7 @@ SharedConfigSample="""
|
||||
|
||||
class TestSharedConfig(unittest.TestCase):
|
||||
def test_sharedconfig(self):
|
||||
sharedconfig = v1.SharedConfig(SharedConfigSample)
|
||||
shared_conf = v1.SharedConfig(shared_config_sample)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -26,23 +26,14 @@ import os
|
||||
import azurelinuxagent.utils.shellutil as shellutil
|
||||
import test
|
||||
|
||||
class TestRunCmd(unittest.TestCase):
|
||||
class TestrunCmd(unittest.TestCase):
|
||||
def test_run_get_output(self):
|
||||
output = shellutil.RunGetOutput("ls /")
|
||||
output = shellutil.run_get_output("ls /")
|
||||
self.assertNotEquals(None, output)
|
||||
self.assertEquals(0, output[0])
|
||||
|
||||
err = shellutil.RunGetOutput("ls /not-exists")
|
||||
err = shellutil.run_get_output("ls /not-exists")
|
||||
self.assertNotEquals(0, err[0])
|
||||
|
||||
def test_run_send_stdin(self):
|
||||
test_sh = os.path.join(env.test_root, "read_stdin.sh")
|
||||
|
||||
output = shellutil.RunSendStdin(test_sh, "y")
|
||||
self.assertEquals(0, output[0])
|
||||
|
||||
output = shellutil.RunSendStdin(test_sh, "n")
|
||||
self.assertEquals(1, output[0])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -27,9 +27,9 @@ import azurelinuxagent.utils.textutil as textutil
|
||||
import test
|
||||
|
||||
class TestTextUtil(unittest.TestCase):
|
||||
def test_GetPasswordHash(self):
|
||||
passwdHash = textutil.GetPasswordHash("asdf", True, 6, 10)
|
||||
self.assertNotEquals(None, passwdHash)
|
||||
def test_get_password_hash(self):
|
||||
password_hash = textutil.gen_password_hash("asdf", True, 6, 10)
|
||||
self.assertNotEquals(None, password_hash)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
+71
-71
@@ -29,96 +29,96 @@ import httplib
|
||||
import azurelinuxagent.logger as logger
|
||||
import azurelinuxagent.protocol.v1 as v1
|
||||
from test_version import VersionInfoSample
|
||||
from test_goalstate import GoalStateSample
|
||||
from test_hostingenv import HostingEnvSample
|
||||
from test_sharedconfig import SharedConfigSample
|
||||
from test_certificates import CertificatesSample, TransportCert
|
||||
from test_extensionsconfig import ExtensionsConfigSample, ManifestSample
|
||||
from test_goalstate import goal_state_sample
|
||||
from test_hostingenv import hosting_env_sample
|
||||
from test_sharedconfig import shared_config_sample
|
||||
from test_certificates import certs_sample, transport_cert
|
||||
from test_extensionsconfig import ext_conf_sample, manifest_sample
|
||||
|
||||
#logger.LoggerInit("/dev/stdout", "/dev/null", verbose=True)
|
||||
#logger.LoggerInit("/dev/stdout", "/dev/null", verbose=False)
|
||||
|
||||
def MockFetchUri(url, headers=None, chkProxy=False):
|
||||
def mock_fetch_uri(url, headers=None, chk_proxy=False):
|
||||
content = None
|
||||
if "versions" in url:
|
||||
content = VersionInfoSample
|
||||
elif "goalstate" in url:
|
||||
content = GoalStateSample
|
||||
content = goal_state_sample
|
||||
elif "hostingenvuri" in url:
|
||||
content = HostingEnvSample
|
||||
content = hosting_env_sample
|
||||
elif "sharedconfiguri" in url:
|
||||
content = SharedConfigSample
|
||||
content = shared_config_sample
|
||||
elif "certificatesuri" in url:
|
||||
content = CertificatesSample
|
||||
content = certs_sample
|
||||
elif "extensionsconfiguri" in url:
|
||||
content = ExtensionsConfigSample
|
||||
content = ext_conf_sample
|
||||
elif "manifest.xml" in url:
|
||||
content = ManifestSample
|
||||
content = manifest_sample
|
||||
else:
|
||||
raise Exception("Bad url {0}".format(url))
|
||||
return content
|
||||
|
||||
def MockFetchManifest(uris):
|
||||
return ManifestSample
|
||||
def mock_fetch_manifest(uris):
|
||||
return manifest_sample
|
||||
|
||||
def MockFetchCache(filePath):
|
||||
def mock_fetch_cache(file_path):
|
||||
content = None
|
||||
if "Incarnation" in filePath:
|
||||
if "Incarnation" in file_path:
|
||||
content = 1
|
||||
elif "GoalState" in filePath:
|
||||
content = GoalStateSample
|
||||
elif "HostingEnvironmentConfig" in filePath:
|
||||
content = HostingEnvSample
|
||||
elif "SharedConfig" in filePath:
|
||||
content = SharedConfigSample
|
||||
elif "Certificates" in filePath:
|
||||
content = CertificatesSample
|
||||
elif "TransportCert" in filePath:
|
||||
content = TransportCert
|
||||
elif "ExtensionsConfig" in filePath:
|
||||
content = ExtensionsConfigSample
|
||||
elif "manifest" in filePath:
|
||||
content = ManifestSample
|
||||
elif "GoalState" in file_path:
|
||||
content = goal_state_sample
|
||||
elif "HostingEnvironmentConfig" in file_path:
|
||||
content = hosting_env_sample
|
||||
elif "SharedConfig" in file_path:
|
||||
content = shared_config_sample
|
||||
elif "Certificates" in file_path:
|
||||
content = certs_sample
|
||||
elif "TransportCert" in file_path:
|
||||
content = transport_cert
|
||||
elif "ExtensionsConfig" in file_path:
|
||||
content = ext_conf_sample
|
||||
elif "manifest" in file_path:
|
||||
content = manifest_sample
|
||||
else:
|
||||
raise Exception("Bad filepath {0}".format(filePath))
|
||||
raise Exception("Bad filepath {0}".format(file_path))
|
||||
return content
|
||||
|
||||
class TestWireClint(unittest.TestCase):
|
||||
|
||||
@Mockup(v1, '_fetchCache', MockFetchCache)
|
||||
def testGet(self):
|
||||
@mock(v1, '_fetch_cache', mock_fetch_cache)
|
||||
def test_get(self):
|
||||
os.chdir('/tmp')
|
||||
client = v1.WireClient("foobar")
|
||||
goalState = client.getGoalState()
|
||||
goalState = client.get_goal_state()
|
||||
self.assertNotEquals(None, goalState)
|
||||
hostingEnv = client.getHostingEnv()
|
||||
hostingEnv = client.get_hosting_env()
|
||||
self.assertNotEquals(None, hostingEnv)
|
||||
sharedConfig = client.getSharedConfig()
|
||||
sharedConfig = client.get_shared_conf()
|
||||
self.assertNotEquals(None, sharedConfig)
|
||||
extensionsConfig = client.getExtensionsConfig()
|
||||
extensionsConfig = client.get_ext_conf()
|
||||
self.assertNotEquals(None, extensionsConfig)
|
||||
|
||||
|
||||
@Mockup(v1, '_fetchCache', MockFetchCache)
|
||||
def testGetHeaderWithCert(self):
|
||||
@mock(v1, '_fetch_cache', mock_fetch_cache)
|
||||
def test_get_head_for_cert(self):
|
||||
client = v1.WireClient("foobar")
|
||||
header = client.getHeaderWithCert()
|
||||
header = client.get_header_for_cert()
|
||||
self.assertNotEquals(None, header)
|
||||
|
||||
@Mockup(v1.WireClient, 'getHeaderWithCert', MockFunc())
|
||||
@Mockup(v1, '_fetchUri', MockFetchUri)
|
||||
@Mockup(v1.fileutil, 'SetFileContents', MockFunc())
|
||||
def testUpdateGoalState(self):
|
||||
@mock(v1.WireClient, 'get_header_for_cert', MockFunc())
|
||||
@mock(v1, '_fetch_uri', mock_fetch_uri)
|
||||
@mock(v1.fileutil, 'write_file', MockFunc())
|
||||
def test_update_goal_state(self):
|
||||
client = v1.WireClient("foobar")
|
||||
client.updateGoalState()
|
||||
goalState = client.getGoalState()
|
||||
self.assertNotEquals(None, goalState)
|
||||
hostingEnv = client.getHostingEnv()
|
||||
self.assertNotEquals(None, hostingEnv)
|
||||
sharedConfig = client.getSharedConfig()
|
||||
self.assertNotEquals(None, sharedConfig)
|
||||
extensionsConfig = client.getExtensionsConfig()
|
||||
self.assertNotEquals(None, extensionsConfig)
|
||||
client.update_goal_state()
|
||||
goal_state = client.get_goal_state()
|
||||
self.assertNotEquals(None, goal_state)
|
||||
hosting_env = client.get_hosting_env()
|
||||
self.assertNotEquals(None, hosting_env)
|
||||
shared_config = client.get_shared_conf()
|
||||
self.assertNotEquals(None, shared_config)
|
||||
ext_conf = client.get_ext_conf()
|
||||
self.assertNotEquals(None, ext_conf)
|
||||
|
||||
class MockResp(object):
|
||||
def __init__(self, status):
|
||||
@@ -126,40 +126,40 @@ class MockResp(object):
|
||||
|
||||
class TestStatusBlob(unittest.TestCase):
|
||||
def testToJson(self):
|
||||
vmStatus = v1.VMStatus()
|
||||
statusBlob = v1.StatusBlob(vmStatus)
|
||||
self.assertNotEquals(None, statusBlob.toJson())
|
||||
vm_status = v1.VMStatus()
|
||||
status_blob = v1.StatusBlob(vm_status)
|
||||
self.assertNotEquals(None, status_blob.toJson())
|
||||
|
||||
@Mockup(v1.restutil, 'HttpPut', MockFunc(retval=MockResp(httplib.CREATED)))
|
||||
@Mockup(v1.restutil, 'HttpHead', MockFunc(retval=MockResp(httplib.OK)))
|
||||
@mock(v1.restutil, 'http_put', MockFunc(retval=MockResp(httplib.CREATED)))
|
||||
@mock(v1.restutil, 'http_head', MockFunc(retval=MockResp(httplib.OK)))
|
||||
def test_put_page_blob(self):
|
||||
vmStatus = v1.VMStatus()
|
||||
statusBlob = v1.StatusBlob(vmStatus)
|
||||
vm_status = v1.VMStatus()
|
||||
status_blob = v1.StatusBlob(vm_status)
|
||||
data = ['a'] * 100
|
||||
statusBlob.putPageBlob("http://foo.bar", data)
|
||||
status_blob.put_page_blob("http://foo.bar", data)
|
||||
|
||||
class TestConvert(unittest.TestCase):
|
||||
def test_status(self):
|
||||
vmStatus = v1.VMStatus()
|
||||
handlerStatus = v1.ExtensionHandlerStatus()
|
||||
vm_status = v1.VMStatus()
|
||||
handler_status = v1.ExtensionHandlerStatus()
|
||||
substatus = v1.ExtensionSubStatus()
|
||||
extStatus = v1.ExtensionStatus()
|
||||
ext_status = v1.ExtensionStatus()
|
||||
|
||||
vmStatus.extensionHandlers.append(handlerStatus)
|
||||
v1.vm_status_to_v1(vmStatus)
|
||||
vm_status.extensionHandlers.append(handler_status)
|
||||
v1.vm_status_to_v1(vm_status)
|
||||
|
||||
handlerStatus.extensionStatusList.append(extStatus)
|
||||
v1.vm_status_to_v1(vmStatus)
|
||||
handler_status.extensionStatusList.append(ext_status)
|
||||
v1.vm_status_to_v1(vm_status)
|
||||
|
||||
extStatus.substatusList.append(substatus)
|
||||
v1.vm_status_to_v1(vmStatus)
|
||||
ext_status.substatusList.append(substatus)
|
||||
v1.vm_status_to_v1(vm_status)
|
||||
|
||||
def test_param(self):
|
||||
param = v1.TelemetryEventParam()
|
||||
event = v1.TelemetryEvent()
|
||||
event.parameters.append(param)
|
||||
|
||||
v1.event_to_xml(event)
|
||||
v1.event_to_v1(event)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -42,11 +42,11 @@ VersionInfoSample="""\
|
||||
class TestVersionInfo(unittest.TestCase):
|
||||
def test_version_info(self):
|
||||
config = v1.VersionInfo(VersionInfoSample)
|
||||
self.assertEquals("2012-11-30", config.getPreferred())
|
||||
self.assertNotEquals(None, config.getSupported())
|
||||
self.assertEquals(2, len(config.getSupported()))
|
||||
self.assertEquals("2010-12-15", config.getSupported()[0])
|
||||
self.assertEquals("2010-28-10", config.getSupported()[1])
|
||||
self.assertEquals("2012-11-30", config.get_preferred())
|
||||
self.assertNotEquals(None, config.get_supported())
|
||||
self.assertEquals(2, len(config.get_supported()))
|
||||
self.assertEquals("2010-12-15", config.get_supported()[0])
|
||||
self.assertEquals("2010-28-10", config.get_supported()[1])
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
+8
-10
@@ -21,7 +21,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
from azurelinuxagent.utils.osutil import OSUtil
|
||||
from azurelinuxagent.utils.osutil import OSUTIL
|
||||
|
||||
parent = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent)
|
||||
@@ -31,9 +31,9 @@ def simple_file_grep(file_path, search_str):
|
||||
if search_str in line:
|
||||
return line
|
||||
|
||||
def Mockup(target, name, mock):
|
||||
def Decorator(func):
|
||||
def Wrapper(*args, **kwargs):
|
||||
def mock(target, name, mock):
|
||||
def decorator(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
origin = getattr(target, name)
|
||||
setattr(target, name, mock)
|
||||
try:
|
||||
@@ -43,8 +43,8 @@ def Mockup(target, name, mock):
|
||||
finally:
|
||||
setattr(target, name, origin)
|
||||
return result
|
||||
return Wrapper
|
||||
return Decorator
|
||||
return wrapper
|
||||
return decorator
|
||||
|
||||
class MockFunc(object):
|
||||
def __init__(self, name='', retval=None):
|
||||
@@ -57,9 +57,7 @@ class MockFunc(object):
|
||||
self.kwargs = kwargs
|
||||
return self.retval
|
||||
|
||||
def Dummy():
|
||||
pass
|
||||
|
||||
#Mock osutil so that the test of other part will be os unrelated
|
||||
OSUtil.GetLibDir = MockFunc(retval='/tmp')
|
||||
OSUtil.GetExtLogDir = MockFunc(retval='/tmp/log')
|
||||
OSUTIL.get_lib_dir = MockFunc(retval='/tmp')
|
||||
OSUTIL.get_ext_log_dir = MockFunc(retval='/tmp/log')
|
||||
|
||||
Reference in New Issue
Block a user